# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import os
import paddle
from paddle.distributed import fleet
from paddle.framework import core
from paddle.distributed.ps.utils.public import *  # noqa: F403
from paddle.static import Program, CompiledProgram, Executor, ParallelExecutor
from paddle.distributed.fleet.runtime.runtime_base import RuntimeBase
from paddle.distributed.fleet.base.private_helper_function import (
    wait_server_ready,
)
from paddle.distributed.fleet.proto import the_one_ps_pb2
from paddle.distributed.communicator import Communicator, HeterClient
from google.protobuf import text_format
from paddle.distributed.ps.coordinator import Coordinator

__all__ = [
    'Table',
    'SparseTable',
    'GeoSparseTable',
    'BarrierTable',
    'TensorTable',
    'DenseTable',
]


def get_program_by_id(context, program_id):
    programs = context["origin_main_programs"]
    for i, program in enumerate(programs):
        if id(program) == program_id:
            return program, context["origin_startup_programs"][i], i
    return None, None, None


def parse_table_class(varname, program_id, context):
    main_program, startup_program, idx = get_program_by_id(context, program_id)
    for op in main_program.global_block().ops:
        if not is_distributed_sparse_op(op) and not is_sparse_op(op):
            continue

        param_name = op.input("W")[0]

        if param_name == varname and op.type in (
            "lookup_table",
            "lookup_table_v2",
        ):
            if op.has_attr('table_class') and op.attr("table_class") != "none":
                return op.attr('table_class')
            else:
                return "MemorySparseTable"


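# Sanity-check that the accessor proto dims match the embedding variable:
# for "SparseAccessor", fea_dim == embedding_dim + 2 and embedx_dim ==
# embedding_dim - 1; for other accessors (e.g. CtrDymfAccessor),
# fea_dim == embedding_dim and embedx_dim == embedding_dim - 3.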
def check_embedding_dim(accessor_proto, varname, program_id, context):
    main_program, startup_program, idx = get_program_by_id(context, program_id)
    embedding_dim = 0
    for var in main_program.list_vars():
        if var.name == varname:
            embedding_dim = var.shape[1]
            print(
                'new var: {}, {}, {}'.format(
                    var, embedding_dim, accessor_proto.fea_dim
                )
            )
            break

    fea_dim = accessor_proto.fea_dim
    if accessor_proto.accessor_class == "SparseAccessor":
        if fea_dim != embedding_dim + 2:
            raise ValueError(
                "The fea_dim is wrong; it should be sparse_embedding_dim + 2: {}, but got {}".format(
                    embedding_dim + 2, fea_dim
                )
            )
    else:
        if fea_dim != embedding_dim:
            raise ValueError(
                "The fea_dim is wrong; it should be sparse_embedding_dim: {}, but got {}".format(
                    embedding_dim, fea_dim
                )
            )

    embedx_dim = accessor_proto.embedx_dim
    if accessor_proto.accessor_class == "SparseAccessor":
        if embedx_dim != embedding_dim - 1:
            raise ValueError(
                "The embedx_dim is wrong; it should be sparse_embedding_dim - 1: {}, but got {}".format(
                    embedding_dim - 1, embedx_dim
                )
            )
    else:
        if embedx_dim != embedding_dim - 3:
            raise ValueError(
                "The embedx_dim is wrong; it should be sparse_embedding_dim - 3: {}, but got {}".format(
                    embedding_dim - 3, embedx_dim
                )
            )


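# Service / GpuService select the RPC implementation written into the PS
# service proto: Brpc* classes for the CPU parameter server, PsLocal* classes
# when running with ps-gpu.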
class Service:
    def __init__(self):
        pass

    def _set(self, service_proto):
        service_proto.server_class = "BrpcPsServer"
        service_proto.client_class = "BrpcPsClient"
        service_proto.service_class = "BrpcPsService"
        service_proto.start_server_port = 0
        service_proto.server_thread_num = 12


class GpuService(Service):
    def __init__(self):
        super().__init__()

    def _set(self, service_proto):
        service_proto.server_class = 'PsLocalServer'
        service_proto.client_class = 'PsLocalClient'


class Accessor:
    def __init__(self):
        self.accessor_class = ""
        self.optimizer = None
        self.feature_dim = 0
        self.embedding_dim = 0

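    # Fill TableAccessorParameter defaults for any field the user strategy
    # left unset (accessor class, dims, CTR params and SGD rules).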
    # TableAccessorParameter accessor
    def _set(
        self, accessor_proto, varname, program_id, context, common_accessor
    ):
        main_program, startup_program, idx = get_program_by_id(
            context, program_id
        )
        embedding_dim = 0
        for var in main_program.list_vars():
            if var.name == varname:
                embedding_dim = var.shape[1]
                break

        if not accessor_proto.HasField("accessor_class"):
            # DownpourSparseValueAccessor
            if context['use_ps_gpu']:
                accessor_proto.accessor_class = "CtrDymfAccessor"
            else:
                accessor_proto.accessor_class = "SparseAccessor"
        if not accessor_proto.HasField("fea_dim"):
            if accessor_proto.accessor_class == "SparseAccessor":
                accessor_proto.fea_dim = embedding_dim + 2
            else:
                accessor_proto.fea_dim = embedding_dim
        if not accessor_proto.HasField("embedx_dim"):
            if accessor_proto.accessor_class == "SparseAccessor":
                accessor_proto.embedx_dim = embedding_dim - 1
            else:
                accessor_proto.embedx_dim = embedding_dim - 3
        if not accessor_proto.HasField("embedx_threshold"):
            accessor_proto.embedx_threshold = 0

        graph_sgd_param = accessor_proto.graph_sgd_param
        if not graph_sgd_param.HasField("nodeid_slot"):
            graph_sgd_param.nodeid_slot = 9008
        if not graph_sgd_param.HasField("feature_learning_rate"):
            graph_sgd_param.feature_learning_rate = 0.05

        ctr_accessor_param = accessor_proto.ctr_accessor_param
        if accessor_proto.embedx_dim == 0:
            ctr_accessor_param.zero_init = False
        if not ctr_accessor_param.HasField("nonclk_coeff"):
            ctr_accessor_param.nonclk_coeff = 0.1
        if not ctr_accessor_param.HasField("click_coeff"):
            ctr_accessor_param.click_coeff = 1.0
        if not ctr_accessor_param.HasField("base_threshold"):
            ctr_accessor_param.base_threshold = 0
        if not ctr_accessor_param.HasField("delta_threshold"):
            ctr_accessor_param.delta_threshold = 0
        if not ctr_accessor_param.HasField("delta_keep_days"):
            ctr_accessor_param.delta_keep_days = 16
        if not ctr_accessor_param.HasField("show_click_decay_rate"):
            ctr_accessor_param.show_click_decay_rate = 1
        if not ctr_accessor_param.HasField("delete_threshold"):
            ctr_accessor_param.delete_threshold = 0
        if not ctr_accessor_param.HasField("delete_after_unseen_days"):
            ctr_accessor_param.delete_after_unseen_days = 30
        if not ctr_accessor_param.HasField("ssd_unseenday_threshold"):
            ctr_accessor_param.ssd_unseenday_threshold = 1

        for sgd_param in [
            accessor_proto.embed_sgd_param,
            accessor_proto.embedx_sgd_param,
        ]:
            if not sgd_param.HasField("name"):
                if common_accessor.accessor_class == "sgd":
                    sgd_param.name = "SparseNaiveSGDRule"
                elif common_accessor.accessor_class == "adam":
                    sgd_param.name = "SparseAdamSGDRule"
                else:  # for fl-ps, because geo accessor is 'sum'
                    sgd_param.name = "SparseAdamSGDRule"

            if (
                sgd_param.name == "SparseAdaGradSGDRule"
                or sgd_param.name == "StdAdaGradSGDRule"
            ):
                if not sgd_param.adagrad.HasField("learning_rate"):
                    sgd_param.adagrad.learning_rate = 0.05
                if not sgd_param.adagrad.HasField("initial_g2sum"):
                    sgd_param.adagrad.initial_g2sum = 3.0
                if not sgd_param.adagrad.HasField("initial_range"):
                    sgd_param.adagrad.initial_range = 0.0001
                if len(sgd_param.adagrad.weight_bounds) == 0:
                    sgd_param.adagrad.weight_bounds.extend([-10.0, 10.0])

            if sgd_param.name == "SparseNaiveSGDRule":
                if not sgd_param.naive.HasField("learning_rate"):
                    learning_rate = common_accessor.initializers[-1].split("&")[
                        1
                    ]
                    sgd_param.naive.learning_rate = float(learning_rate)
                if not sgd_param.naive.HasField("initial_range"):
                    initial_range = common_accessor.initializers[0].split("&")[
                        -1
                    ]
                    sgd_param.naive.initial_range = float(initial_range)
                if len(sgd_param.naive.weight_bounds) == 0:
                    sgd_param.naive.weight_bounds.extend([-10.0, 10.0])

            if (
                sgd_param.name == "SparseAdamSGDRule"
                or sgd_param.name == "SparseSharedAdamSGDRule"
            ):
                if not sgd_param.adam.HasField("learning_rate"):
                    learning_rate = common_accessor.initializers[-1].split("&")[
                        1
                    ]
                    sgd_param.adam.learning_rate = float(learning_rate)
                if not sgd_param.adam.HasField("initial_range"):
                    initial_range = common_accessor.initializers[0].split("&")[
                        -1
                    ]
                    sgd_param.adam.initial_range = float(initial_range)

                attr_list = [x.split("&") for x in common_accessor.attrs]
                if (
                    not sgd_param.adam.HasField("beta1_decay_rate")
                    and common_accessor.accessor_class == "adam"
                ):
                    sgd_param.adam.beta1_decay_rate = float(attr_list[0][1])
                else:
                    sgd_param.adam.beta1_decay_rate = 0.9
                if (
                    not sgd_param.adam.HasField("beta2_decay_rate")
                    and common_accessor.accessor_class == "adam"
                ):
                    sgd_param.adam.beta2_decay_rate = float(attr_list[1][1])
                else:
                    sgd_param.adam.beta2_decay_rate = 0.999
                if (
                    not sgd_param.adam.HasField("ada_epsilon")
                    and common_accessor.accessor_class == "adam"
                ):
                    sgd_param.adam.ada_epsilon = float(attr_list[2][1])
                else:
                    sgd_param.adam.ada_epsilon = 1e-08
                if len(sgd_param.adam.weight_bounds) == 0:
                    sgd_param.adam.weight_bounds.extend([-10.0, 10.0])


class CommonAccessor(Accessor):
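    # Parses the optimizer op that updates a parameter and converts it into
    # the "common" section of a table proto: formal input names, per-pserver
    # shard dims, "&"-joined initializer strings (e.g. "fill_constant&0") and
    # optimizer attrs.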
    def __init__(self):
        super().__init__()
        self.table_name = ''
        self.entry = 'none'
        self.attrs = []
        self.params = []
        self.dims = []
        self.trainer_num = 0
        self.sync = False
        self.initializers = []
        self.opt_input_map = {}
        self.opt_attr_map = {}
        self.opt_init_map = {}
        self.define_optimize_map()

    def define_optimize_map(self):
        opt_input_map = {}
        opt_input_map["sgd"] = [("Param", None), ("LearningRate", 1)]
        opt_input_map["adam"] = [
            ("Param", None),
            ("Moment1", None),
            ("Moment2", None),
            ("Beta1Pow", 1),
            ("Beta2Pow", 1),
            ("LearningRate", 1),
        ]
        opt_input_map["adam_d2sum"] = [
            ("Param", None),
            ("D2Sum", None),
            ("G2Sum", None),
            ("Moment", None),
            ("MomentDecayRate", 1),
            ("AdaDecayRate", 1),
            ("AdaEpsilon", 1),
            ("LearningRate", 1),
        ]
        opt_input_map["sum"] = [("Param", None)]
        opt_input_map["naive_adagrad"] = [
            ("Param", None),
            ("G2Sum", 1),
            ("LearningRate", 1),
        ]
        opt_input_map["summary"] = [("Param", None), ("SummaryDecayRate", 1)]

        opt_attr_map = {}
        opt_attr_map["sgd"] = []
        opt_attr_map["sum"] = []
        opt_attr_map["naive_adagrad"] = []
        opt_attr_map["adam"] = [
            ("beta1", "f"),
            ("beta2", "f"),
            ("epsilon", "f"),
        ]
        opt_attr_map["adam_d2sum"] = [
            ("beta1", "f"),
            ("beta2", "f"),
            ("epsilon", "f"),
        ]
        opt_attr_map["summary"] = [("summary_decay_rate", "f")]

        opt_init_map = {}
        opt_init_map["gaussian_random"] = ["seed", "mean", "std"]
        opt_init_map["fill_constant"] = ["value"]
        opt_init_map["uniform_random"] = ["seed", "min", "max"]
        opt_init_map["truncated_gaussian_random"] = ["seed", "mean", "std"]

        self.opt_attr_map = opt_attr_map
        self.opt_input_map = opt_input_map
        self.opt_init_map = opt_init_map

    def parse_entry(self, varname, program_id, context):
        main_program, startup_program, idx = get_program_by_id(
            context, program_id
        )
        for op in main_program.global_block().ops:
            if not is_distributed_sparse_op(op) and not is_sparse_op(op):
                continue

            param_name = op.input("W")[0]

            if param_name == varname and op.type == "lookup_table":
                self.entry = op.attr('entry')
                break

            if param_name == varname and op.type == "lookup_table_v2":
                self.entry = "none"
                break

    def get_shard(self, total_dim, shard_num, pserver_id):
        blocksize = int(total_dim / shard_num + 1)
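        # e.g. total_dim=10, shard_num=3 -> blocksize=4, so the three pservers
        # hold 4, 4 and 2 rows respectively.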

        if blocksize * (pserver_id + 1) <= total_dim:
            return blocksize
        else:
            if blocksize * pserver_id < total_dim:
                return total_dim - blocksize * pserver_id
            else:
                return 0

    def get_initializer_attr(self, value_name, o_startup_program):
        l_in = "&"
        attr_str = ""

        origin_var_name = value_name
        # print("get_initializer_attr param name:", value_name)
        for op in o_startup_program.global_block().ops:
            if (
                op.type in self.opt_init_map.keys()
                and origin_var_name == op.output("Out")[0]
            ):
                init_attr = [op.type]
                # print("get_initializer_attr op type:", op.type)
                for attr in self.opt_init_map[op.type]:
                    # print("get_initializer_attr opt_init_map attr:", attr)
                    init_attr.append(str(op.attr(attr)))
                    # print("get_initializer_attr op attr:", str(op.attr(attr)))
                attr_str = l_in.join(init_attr)
                break
        return attr_str

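    # Build params/dims/initializers/attrs for the table that receives
    # `grad_name`, based on the optimizer op found in the origin main program.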
    def parse_by_optimizer(self, ctx, context):
        grad_name = ctx.origin_varnames()[0]
        is_sparse = ctx.is_sparse()
        size = ctx.sections()[0]
        single_dim = ctx.sections()[1] if ctx.is_sparse() else 1
        adam_d2sum = context["user_defined_strategy"].adam_d2sum
        # print("parse_by_optimizer table_id:{} is_datanorm:{}".format(
        #     ctx.table_id(), ctx.is_datanorm_table()))

        main_program, startup_program, idx = get_program_by_id(
            context, ctx.program_id()
        )
        pserver_id = get_role_id(context['role_maker'])
        pserver_num = len(get_ps_endpoints(context['role_maker']))
        optimizer_ops = get_optimize_ops(main_program)
        # print("the one ps optimizer_ops:", optimizer_ops)
        # print("the one ps parse_by_optimizer grad_name:", grad_name)
        oop = None

        for op in optimizer_ops:
            if ("Param" in op.input_names) and (
                op.input("Param")[0]
                == context['grad_name_to_param_name'][grad_name]
            ):
                oop = op
                break

        if oop is None:
            raise ValueError("can not find optimizer for {}".format(grad_name))

        params = []
        dims = []
        attrs = []
        initializers = []

        self.trainer_num = get_trainers(context['role_maker'])
        self.table_num = size
        self.table_dim = single_dim

        if oop.type != 'adam' and adam_d2sum:
            print('optimization algorithm is not adam, set adam_d2sum False')
            adam_d2sum = False
        print("adam_d2sum:", adam_d2sum)
        if context['ps_mode'] == DistributedMode.GEO:
            param_varnames = self.opt_input_map["sum"]
            attr_varnames = self.opt_attr_map["sum"]
            self.accessor_class = "sum"
        elif context['use_ps_gpu'] and is_sparse:
            param_varnames = self.opt_input_map["naive_adagrad"]
            attr_varnames = self.opt_attr_map["naive_adagrad"]
            self.accessor_class = "sgd"
        elif ctx.is_datanorm_table():
            param_varnames = self.opt_input_map["summary"]
            attr_varnames = self.opt_attr_map["summary"]
            self.accessor_class = "summary"
        elif adam_d2sum and not is_sparse:
            param_varnames = self.opt_input_map["adam_d2sum"]
            attr_varnames = self.opt_attr_map["adam_d2sum"]
            self.accessor_class = "adam_d2sum"
        else:
            if oop.type != 'sgd' and oop.type != 'adam':
                raise ValueError(
                    "Only SGD and Adam are supported as the dense optimizer in PS!"
                )
            param_varnames = self.opt_input_map[oop.type]
            attr_varnames = self.opt_attr_map[oop.type]
            self.accessor_class = oop.type

        for (formal_name, shape) in param_varnames:
            params.append(formal_name)
            if self.accessor_class == "adam_d2sum":
                # for dims
                if shape is None:
                    if is_sparse:
                        shape = single_dim
                    else:
                        shape = self.get_shard(size, pserver_num, pserver_id)
                dims.append(shape)

                # for initializers
                if formal_name == "Param" or formal_name == "LearningRate":
                    param = main_program.global_block().vars[
                        oop.input(formal_name)[0]
                    ]
                    # TODO: for dense learning_rate, can be different from sparse lr
                    if (
                        formal_name == "LearningRate"
                        and param.name != "learning_rate_" + str(idx)
                    ):
                        warnings.warn("will support decay soon")
                        param = main_program.global_block().vars[
                            "learning_rate_" + str(idx)
                        ]

                    initializer = self.get_initializer_attr(
                        param.name, startup_program
                    )
                elif formal_name == "MomentDecayRate":
                    initializer = "fill_constant&0.99"
                elif formal_name == "AdaDecayRate":
                    initializer = "fill_constant&0.9999"
                elif formal_name == "AdaEpsilon":
                    initializer = "fill_constant&1.0e-8"
                else:
                    initializer = "fill_constant&0"
                initializers.append(initializer)
            elif self.accessor_class == "summary":
                # for dims
                if shape is None:
                    if is_sparse:
                        shape = single_dim
                    else:
                        shape = self.get_shard(size, pserver_num, pserver_id)
                dims.append(shape)

                # for initializers
                if formal_name == "Param":
                    param = main_program.global_block().vars[
                        oop.input(formal_name)[0]
                    ]

                    initializer = self.get_initializer_attr(
                        param.name, startup_program
                    )
                elif formal_name == "SummaryDecayRate":
                    initializer = "fill_constant&0.999999"
                else:
                    initializer = "fill_constant&0"
                initializers.append(initializer)
            else:
                if formal_name == "G2Sum":
                    dims.append(1)
                    initializer = "fill_constant&0"
                    initializers.append(initializer)
                else:
                    param = main_program.global_block().vars[
                        oop.input(formal_name)[0]
                    ]
                    if (
                        formal_name == "LearningRate"
                        and param.name != "learning_rate_" + str(idx)
                    ):
                        warnings.warn("will support decay soon")
                        param = main_program.global_block().vars[
                            "learning_rate_" + str(idx)
                        ]

                    if shape is None:
                        if is_sparse:
                            shape = single_dim
                        else:
                            shape = self.get_shard(
                                size, pserver_num, pserver_id
                            )
                    dims.append(shape)

                    initializer = self.get_initializer_attr(
                        param.name, startup_program
                    )
                    initializers.append(initializer)

        if self.accessor_class == 'summary':
            datanorm_ops = get_datanorm_ops(main_program)
            for op in datanorm_ops:
                if ("BatchSize" in op.input_names) and (
                    op.input("BatchSize")[0]
                    == context['grad_name_to_param_name'][grad_name]
                ):
                    oop = op
                    break

        for (attr_varname, type_) in attr_varnames:
            value = oop.attr(attr_varname)
            attrs.append("&".join([attr_varname, str(value)]))

        self.params = params
        self.dims = dims
        self.initializers = initializers
        self.attrs = attrs

    # CommonAccessorParameter common
    def _set(self, proto):
        proto.name = self.accessor_class
        proto.table_name = self.table_name
        proto.params.extend(self.params)
        proto.dims.extend(self.dims)
        proto.initializers.extend(self.initializers)
        proto.entry = self.entry
        proto.trainer_num = self.trainer_num
        proto.sync = self.sync
        proto.table_num = self.table_num
        proto.table_dim = self.table_dim
        proto.attr = "#".join(self.attrs)


class Tensor:
    def __init__(self, tensor_dict):
        self.tensor_dict = tensor_dict

    def _set(self, tensor_proto):
        tensor_proto.main_program_id = self.tensor_dict.get(
            "main_program_id", 0
        )
        tensor_proto.startup_program_id = self.tensor_dict.get(
            "startup_program_id", 0
        )
        tensor_proto.feed_var_name = self.tensor_dict.get("feed_var_name", '')
        tensor_proto.fetch_var_name = self.tensor_dict.get("fetch_var_name", '')
        tensor_proto.tensor_table_class = self.tensor_dict.get(
            "tensor_table_class", ''
        )


class Table:
    def __init__(self):
        self.table_class = None
        self.shard_num = -1
        self.type = None
        self.accessor = Accessor()
        self.shard_num = 256
        self.common = CommonAccessor()
        self.tensor = None

    def _set(self, table_proto):
        pass


class BarrierTable(Table):
    def __init__(self, context, idx):
        super().__init__()
        self.type = None
        self.shard_num = 256
        self.accessor.accessor_class = 'CommMergeAccessor'
        self.common.attrs = ""
        self.common.dims = []
        self.common.params = []
        self.is_heter_ps_mode = context['is_heter_ps_mode']
        self.role_maker = context['role_maker']
        self.idx = idx
        self.is_sync = context['is_sync']

    def _set(self, table_proto):
        table_proto.table_id = self.idx
        table_proto.table_class = 'BarrierTable'
        table_proto.shard_num = 256
        table_proto.type = the_one_ps_pb2.PS_OTHER_TABLE

        table_proto.accessor.accessor_class = "CommMergeAccessor"
        table_proto.accessor.fea_dim = 0
        table_proto.accessor.embedx_dim = 0

        table_proto.common.name = ""
        table_proto.common.table_name = "barrier_table"
        table_proto.common.sync = self.is_sync
        table_proto.common.entry = 'none'

        trainer_num = get_trainers(self.role_maker)
        if self.is_heter_ps_mode:
            trainer_num += len(self.role_maker._get_heter_worker_endpoints())
        table_proto.common.trainer_num = trainer_num


class TensorTable(Table):
    def __init__(self, idx, tensor_dict, role_maker):
        super().__init__()
        self.idx = idx
        self.tensor_dict = tensor_dict
        self.role_maker = role_maker

    def _set(self, table_proto):
        table_proto.table_id = self.idx
        table_proto.type = the_one_ps_pb2.PS_OTHER_TABLE
        table_proto.table_class = self.tensor_dict.get("tensor_table_class", '')

        table_proto.accessor.accessor_class = "CommMergeAccessor"

        table_proto.common.table_name = self.tensor_dict.get(
            "feed_var_name", ''
        )
        table_proto.common.trainer_num = get_trainers(self.role_maker)

        tensor = Tensor(self.tensor_dict)
        tensor._set(table_proto.tensor)


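# SparseTable turns a sparse send-context into a downpour_table_param entry,
# merging any user-defined sparse_table_configs over the built-in defaults.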
class SparseTable(Table):
    def __init__(self, context, send_ctx):
        super().__init__()
        self.context = context
        self.ctx = send_ctx
        self.type = None
        self.table_class = 'MemorySparseTable'
        self.accessor = Accessor()

    def _set(self, table_proto):
        ctx = self.ctx
        if (
            ctx.is_tensor_table()
            or len(ctx.origin_varnames()) < 1
            or (not ctx.is_sparse())
        ):
            return
        table_proto.table_id = ctx.table_id()
        table_proto.table_class = self.table_class
        table_proto.type = the_one_ps_pb2.PS_SPARSE_TABLE
        table_proto.shard_num = self.shard_num
        if table_proto.sparse_table_cache_file_num > len(
            get_ps_endpoints(self.context['role_maker'])
        ):
            table_proto.sparse_table_cache_file_num = len(
                get_ps_endpoints(self.context['role_maker'])
            )

        self.common.table_name = self.context['grad_name_to_param_name'][
            ctx.origin_varnames()[0]
        ]

        self.common.parse_by_optimizer(ctx, self.context)
        self.common.parse_entry(
            self.common.table_name, ctx.program_id(), self.context
        )
        self.common.sync = True if self.context['is_sync'] else False

        self.common._set(table_proto.common)

        print('new table_name: {}'.format(self.common.table_name))
        all_table_proto = self.context[
            "user_defined_strategy"
        ].sparse_table_configs
        usr_table_proto = all_table_proto.add()
        for proto in all_table_proto:
            if proto.table_name == self.common.table_name:
                usr_table_proto = proto
                break
        if usr_table_proto.HasField("table_class"):
            table_proto.table_class = usr_table_proto.table_class
        else:
            table_proto.table_class = 'MemorySparseTable'
            warnings.warn("The PS mode must use MemorySparseTable.")
        if usr_table_proto.HasField("shard_num"):
            table_proto.shard_num = usr_table_proto.shard_num
        else:
            if self.context['use_ps_gpu']:
                table_proto.shard_num = 37
                warnings.warn(
                    "The shard_num of sparse table is not set, use default value 37 in gpups."
                )
            else:
                table_proto.shard_num = 1000
                warnings.warn(
                    "The shard_num of sparse table is not set, use default value 1000 in cpups."
                )

        if usr_table_proto.HasField("enable_sparse_table_cache"):
            table_proto.enable_sparse_table_cache = (
                usr_table_proto.enable_sparse_table_cache
            )
        if usr_table_proto.HasField("sparse_table_cache_rate"):
            table_proto.sparse_table_cache_rate = (
                usr_table_proto.sparse_table_cache_rate
            )
        if usr_table_proto.HasField("sparse_table_cache_file_num"):
            table_proto.sparse_table_cache_file_num = (
                usr_table_proto.sparse_table_cache_file_num
            )
        if usr_table_proto.HasField("enable_revert"):
            table_proto.enable_revert = usr_table_proto.enable_revert
        if usr_table_proto.HasField("shard_merge_rate"):
            table_proto.shard_merge_rate = usr_table_proto.shard_merge_rate

        if usr_table_proto.accessor.ByteSize() == 0:
            warnings.warn(
                "The accessor of sparse table is not set, use default value."
            )

        table_proto.accessor.ParseFromString(
            usr_table_proto.accessor.SerializeToString()
        )
        self.accessor._set(
            table_proto.accessor,
            self.common.table_name,
            ctx.program_id(),
            self.context,
            self.common,
        )
        check_embedding_dim(
            table_proto.accessor,
            self.common.table_name,
            ctx.program_id(),
            self.context,
        )


class GeoSparseTable(SparseTable):
    def __init__(self, context, send_ctx):
        super().__init__(context, send_ctx)
        self.table_class = "MemorySparseGeoTable"
        if self.context['ps_mode'] != DistributedMode.GEO:
            raise ValueError("not geo sparse table!")

    def _set(self, table_proto):
        ctx = self.ctx
        if (
            ctx.is_tensor_table()
            or len(ctx.origin_varnames()) < 1
            or (not ctx.is_sparse())
        ):
            return
        table_proto.table_id = ctx.table_id()
        table_proto.table_class = self.table_class
        table_proto.type = the_one_ps_pb2.PS_SPARSE_TABLE
        table_proto.shard_num = self.shard_num

        table_proto.accessor.accessor_class = 'CommMergeAccessor'
        table_proto.accessor.fea_dim = ctx.sections()[0]
        table_proto.accessor.embedx_dim = ctx.sections()[1]

        self.common.table_name = self.context['grad_name_to_param_name'][
            ctx.origin_varnames()[0]
        ]
        self.common.parse_by_optimizer(ctx, self.context)
        self.common.parse_entry(
            self.common.table_name, ctx.program_id(), self.context
        )
        self.common.sync = False
        self.common._set(table_proto.common)


class DenseTable(Table):
    def __init__(self, context, send_ctx):
        super().__init__()
        self.context = context
        self.ctx = send_ctx
        self.accessor = Accessor()

    def _set(self, table_proto):
        ctx = self.ctx
        if (
            ctx.is_tensor_table()
            or len(ctx.origin_varnames()) < 1
            or (ctx.is_sparse())
        ):
            return

        table_proto.table_id = ctx.table_id()

        table_proto.type = the_one_ps_pb2.PS_DENSE_TABLE
        table_proto.table_class = "MemoryDenseTable"
        table_proto.shard_num = 256

        table_proto.accessor.accessor_class = 'CommMergeAccessor'
        table_proto.accessor.fea_dim = ctx.sections()[0]
        table_proto.accessor.embedx_dim = 1

        self.common.table_name = "MergedDense"
        self.common.parse_by_optimizer(ctx, self.context)
        self.common.parse_entry(
            self.common.table_name, ctx.program_id(), self.context
        )
        self.common.sync = True if self.context['is_sync'] else False

        self.common._set(table_proto.common)


class Server:
    def __init__(self):
        pass

    def _set(self):
        pass


class DownpourServer(Server):
    def __init__(self):
        super().__init__()

    def _set(self):
        pass


class Worker:
    def __init__(self):
        pass

    def _set(self):
        pass


class DownpourWorker(Worker):
    def __init__(self):
        super().__init__()

    def _set(self):
        pass


class fsClient:
    def __init__(self, fs_client_param):
        self.fs_client_param = fs_client_param

    def _set(self, proto):
        if not text_format.MessageToString(self.fs_client_param):
            return
        proto.uri = self.fs_client_param.uri
        proto.user = self.fs_client_param.user
        proto.passwd = self.fs_client_param.passwd
        proto.hadoop_bin = self.fs_client_param.hadoop_bin


class PsDescBuilder:
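    # Builds the PSParameter text-protobuf descriptors consumed by the worker
    # and server: one downpour_table_param per dense/sparse send context, plus
    # tensor tables and a trailing BarrierTable.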
    def __init__(self, context):
        self.context = context
        self.is_sync = context['is_sync']
        self.ps_mode = context['ps_mode']
        self.is_heter_ps_mode = context['is_heter_ps_mode']
        self.use_ps_gpu = context['use_ps_gpu']
        self.barrier_table_id = None

        self.send_ctx = get_the_one_send_context(
            self.context, split_dense_table=self.is_heter_ps_mode
        )

        self.tensor_table_dict = {}  # TODO
        self._server_sub_program = []

        self.tables = self._get_tables()

        self.service = self._get_service()
        self.fs_client = self._get_fs_client()

        self.ps_desc = the_one_ps_pb2.PSParameter()
        self.fl_desc = the_one_ps_pb2.FLParameter()

    def _get_tensor_tables(self):
        program_idx = 0
        if not self.tensor_table_dict:
            self._server_sub_program.append(Program().desc)
        tables = []
        for table_name in self.tensor_table_dict:
            tables.append(
                globals()['TensorTable'](
                    len(tables),
                    self.tensor_table_dict[table_name],
                    self.context['role_maker'],
                )
            )
            program_idx += 1
        return tables

    def _get_tables(self):
        tables = []
        for idx, (name, ctx) in enumerate(self.send_ctx.items()):
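            # Dense send-contexts become DenseTable; sparse ones become
            # GeoSparseTable in GEO mode (or SparseTable when fl-ps
            # local_sparse is set and this parameter is not in it), otherwise
            # SparseTable.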
            print("idx, name, ctx:", idx, name, ctx)
            if ctx.is_sparse():
                if self.ps_mode == DistributedMode.GEO:
                    if (
                        self.context['local_sparse']
                        and name[:-5] in self.context['local_sparse']
                    ) or (not self.context['local_sparse']):
                        tables.append(
                            globals()['GeoSparseTable'](self.context, ctx)
                        )
                    else:
                        tables.append(
                            globals()['SparseTable'](self.context, ctx)
                        )
                else:
                    tables.append(globals()['SparseTable'](self.context, ctx))
            else:
                tables.append(globals()['DenseTable'](self.context, ctx))
        self.tensor_tables = self._get_tensor_tables()
        tables.extend(self.tensor_tables)
        tables.append(globals()['BarrierTable'](self.context, len(tables)))
        return tables

    def _get_service(self):
        if self.use_ps_gpu:
            return GpuService()
        else:
            return Service()

    def _get_fs_client(self):
        return fsClient(self.context["user_defined_strategy"].fs_client_param)

    def build_fl_client_desc(self, client_info):
        pass

    def build_worker_desc(self):
        for table in self.tables:
            table_proto = (
                self.ps_desc.worker_param.downpour_worker_param.downpour_table_param.add()
            )
            table._set(table_proto)
            table_proto = (
                self.ps_desc.server_param.downpour_server_param.downpour_table_param.add()
            )
            table._set(table_proto)
            if type(table) == BarrierTable and self.barrier_table_id is None:
                self.barrier_table_id = table.idx
        self.service._set(
            self.ps_desc.server_param.downpour_server_param.service_param
        )
        self.fs_client._set(self.ps_desc.fs_client_param)
        return text_format.MessageToString(self.ps_desc)

    def build_server_desc(self):
        self.sparse_table_maps = {}
        for table in self.tables:
            table_proto = (
                self.ps_desc.server_param.downpour_server_param.downpour_table_param.add()
            )
            table._set(table_proto)
            if (
                table_proto.type == the_one_ps_pb2.PS_SPARSE_TABLE
                and table_proto.common is not None
            ):
                self.sparse_table_maps[
                    table_proto.common.table_name
                ] = table_proto.table_id

        self.service._set(
            self.ps_desc.server_param.downpour_server_param.service_param
        )
        self.fs_client._set(self.ps_desc.fs_client_param)
        return text_format.MessageToString(self.ps_desc)


class TheOnePSRuntime(RuntimeBase):
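    # Fleet runtime for parameter-server training: builds the PS descriptors,
    # initializes the worker/server through core.DistFleetWrapper, and wires
    # up the Communicator / HeterClient / Coordinator as needed.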
    def __init__(self):
        super().__init__()
        self._communicator = None
        self._server = None
        self._worker = core.DistFleetWrapper()
        self._coordinator = None
        self._server_sub_program = []
        self._heter_client = None
        self._send_ctx = None

    def _set_basic_info(self, context):
        self.context = context
        self.role_maker = context["role_maker"]
        self.role_id = get_role_id(self.role_maker)
        self.debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))

        self.origin_main_program = context["origin_main_program"]
        self.origin_main_programs = context.get(
            "origin_main_programs", [self.origin_main_program]
        )
        self.context["origin_main_programs"] = self.origin_main_programs
        self.context["origin_startup_programs"] = context.get(
            'origin_startup_programs', [context['origin_startup_program']]
        )
        self.context[
            'is_heter_ps_mode'
        ] = self.role_maker._is_heter_parameter_server_mode
        self.is_heter_ps_mode = self.context['is_heter_ps_mode']
        self.context['trainer'] = TrainerRuntimeConfig(
            context['valid_strategy']
        )
        self.context['ps_mode'] = self.context['trainer'].mode
        self.context['use_ps_gpu'] = context['valid_strategy'].a_sync_configs[
            'use_ps_gpu'
        ]
        self.context['is_sync'] = (
            True if self.context['ps_mode'] == DistributedMode.SYNC else False
        )
        self.context['grad_name_to_param_name'] = {}
        self.context['tensor_table'] = {}
        # FL
        self.context['local_sparse'] = context[
            "user_defined_strategy"
        ].trainer_desc_configs["local_sparse"]
        self.context['remote_sparse'] = context[
            "user_defined_strategy"
        ].trainer_desc_configs["remote_sparse"]
        print(
            "fl-ps > local_sparse: {}, remote_sparse: {}".format(
                self.context['local_sparse'], self.context['remote_sparse']
            )
        )

        build_var_distributed(self.context)

        self.trainer_endpoints = get_trainer_endpoints(self.role_maker)

        self.endpoints = get_ps_endpoints(self.role_maker)
        self.string_hosts = []
        for idx, ep in enumerate(self.endpoints):
            host, port = ep.split(":")
            pshost = core.PSHost(host, int(port), idx)
            self.string_hosts.append(pshost.serialize_to_string())

        self.with_coordinator = self.role_maker._with_coordinator
        self.coordinator_hosts = []
        if self.with_coordinator:
            print("fl-ps > all ps addrs: {}".format(self.string_hosts))
            coordinator_endpoints = self.role_maker._get_coordinator_endpoints()
            for idx, ep in enumerate(coordinator_endpoints):
                ip, port = ep.split(":")
                pshost = core.PSHost(ip, int(port), idx)
                self.coordinator_hosts.append(pshost.serialize_to_string())

        self.ps_desc_builder = PsDescBuilder(self.context)

    def _init_all_params(self, scopes, send_ctx, recv_map):
        all_var_names = []
        for name, ctx in send_ctx.items():
            if ctx.is_sparse():
                continue
            _, _, idx = get_program_by_id(self.context, ctx.program_id())
            scope = scopes[idx]
            table_id = ctx.table_id()
            var_names = recv_map[table_id]
            # print("init params:", idx, table_id, var_names)
            self._worker.push_dense_params(scope, table_id, var_names)
            all_var_names.extend(var_names)
        return all_var_names

    def _pull_all_dense(self, scopes, send_ctx, recv_map):
        all_var_names = []
        for name, ctx in send_ctx.items():
            if ctx.is_sparse():
                continue
            _, _, idx = get_program_by_id(self.context, ctx.program_id())
            scope = scopes[idx]
            table_id = ctx.table_id()
            var_names = recv_map[table_id]
            # print("pull all dense:", idx, table_id, var_names)
            self._worker.pull_dense_params(scope, table_id, var_names)
            all_var_names.extend(var_names)
        return all_var_names

    def _init_params(self, program, scope, send_ctx, recv_map):
        all_var_names = []
        for name, ctx in send_ctx.items():
            if ctx.is_sparse():
                continue
            if ctx.program_id() != id(program):
                continue
            table_id = ctx.table_id()
            var_names = recv_map[table_id]
            # print("init params:", table_id, var_names)
            self._worker.push_dense_params(scope, table_id, var_names)
            all_var_names.extend(var_names)
        return all_var_names

    def _pull_dense(self, program, scope, send_ctx, recv_map):
        all_var_names = []
        for name, ctx in send_ctx.items():
            if ctx.is_sparse():
                continue
            if ctx.program_id() != id(program):
                continue
            table_id = ctx.table_id()
            var_names = recv_map[table_id]
            # print("pull dense:", table_id, var_names)
            self._worker.pull_dense_params(scope, table_id, var_names)
            all_var_names.extend(var_names)
        return all_var_names

    def _init_worker(self, scopes=None):
        worker_desc = self.ps_desc_builder.build_worker_desc()
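        # In ps-gpu mode, record the selected GPUs on the main program and
        # initialize the PSGPU engine with the same device list.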
        if self.context['use_ps_gpu']:
            main_program = self.context['loss'].block.program
            if not main_program._fleet_opt:
                main_program._fleet_opt = {}
            main_program._fleet_opt["use_ps_gpu"] = True
            gpus_env = os.getenv("FLAGS_selected_gpus")
            gpus_env = [int(s) for s in gpus_env.split(",")]
            main_program._fleet_opt["worker_places"] = gpus_env
            PSGPU = core.PSGPU()
            PSGPU.init_gpu_ps(gpus_env)

        def sync_strategy_envs():
            kwargs = {}
            kwargs[
                "pserver_endpoints"
            ] = self.role_maker._get_pserver_endpoints()
            kwargs["trainer_id"] = self.role_maker._worker_index()
            return kwargs

        dense_map = get_the_one_recv_context(
            self.context, split_dense_table=self.is_heter_ps_mode
        )
        send_ctx = get_the_one_send_context(
            self.context,
            split_dense_table=self.is_heter_ps_mode,
            ep_list=self.endpoints,
        )
        self._send_ctx = send_ctx
        trainer_config = self.context['trainer']

        if self.debug:
            print("worker_desc: \n{}".format(worker_desc))
            print("communicator send_ctx:")
            for key in send_ctx:
                print("{}: {}".format(key, send_ctx[key]))
            for key in dense_map:
                print("{}: {}".format(key, dense_map[key]))

        kwargs = {}
        kwargs['need_global_step'] = "0"
        kwargs["trainer_id"] = self.role_maker._role_id()
        kwargs["trainers"] = self.role_maker._worker_num()

        kwargs["barrier_table_id"] = self.ps_desc_builder.barrier_table_id

        if self.context['ps_mode'] == DistributedMode.SYNC:
            sync_kwargs = sync_strategy_envs()
            kwargs.update(sync_kwargs)

        print("communicator config:", trainer_config.get_communicator_flags())

        self._worker.init_worker(worker_desc, self.string_hosts, self.role_id)
        if not self.is_heter_ps_mode:
            self.trainer_endpoint = get_trainer_endpoint(self.role_maker)
            print("fl-ps > trainer_endpoint: {}".format(self.trainer_endpoint))
        print("fl-ps > with_coordinator? {}".format(self.with_coordinator))
        print("fl-ps > coordinator addr: {}".format(self.coordinator_hosts))
        if self.with_coordinator:
            self._worker.init_fl_worker(
                self.coordinator_hosts, self.role_id, self.trainer_endpoint
            )

        if (
            self.context['ps_mode'] == DistributedMode.GEO
            or self.is_heter_ps_mode
        ):
            self._communicator = Communicator(
                trainer_config.mode,
                kwargs,
                trainer_config.get_communicator_flags(),
            )
            self._communicator.init_with_ctx(
                send_ctx,
                dense_map,
                worker_desc,
                self.string_hosts,
                paddle.static.global_scope(),
            )
        fleet.util.barrier()

        # info = self._communicator.get_client_info()
        info = self._worker.get_client_info()
        if isinstance(info, list) and len(info) > 0:
            all_info = self.role_maker._all_gather(
                info[0]
            )  # gather the service addresses of the other clients
            # for unittest
            if not isinstance(all_info, list):
                warnings.warn("gloo may not initialize correctly")
                all_info = [all_info]

            # self._communicator.set_clients(all_info)
            # self._communicator.create_client_to_client_connection()
            self._worker.set_clients(all_info)
            self._worker.create_client2client_connection()
            print('create c2c connection done')
        else:
            print('cannot create c2c connection')

        dist_strategy = self.context["valid_strategy"]

        is_test = bool(int(os.getenv("TEST_MODE", "0")))

        if scopes is None:
            if len(self.origin_main_programs) > 1:
                raise ValueError(
                    "You must set the scope list when you have multiple programs"
                )
            scopes = [paddle.static.global_scope()]
        if len(self.origin_main_programs) != len(scopes):
            raise ValueError("len(programs) != len(scopes)")

        self.scopes = scopes
        if not is_test:
            if (
                self.context['ps_mode'] == DistributedMode.GEO
                or self.is_heter_ps_mode
            ):
                self._communicator.init_params(dense_map)
            else:
                if not self.context['use_ps_gpu']:
                    if self.role_id == 0:
                        print("entering self._init_all_params()")
                        self._init_all_params(scopes, send_ctx, dense_map)

            fleet.util.barrier()  # make sure worker 0 has finished push_dense_param

        if not self.context['use_ps_gpu']:
            self._pull_all_dense(scopes, send_ctx, dense_map)
        fleet.util.barrier()

        if (
            self.context['ps_mode'] == DistributedMode.GEO
            or self.is_heter_ps_mode
        ):
            if not self._communicator.is_running():
                self._communicator.start()
            else:
                warnings.warn("communicator has been initialized, skip")

        launch_barrier = dist_strategy.a_sync_configs["launch_barrier"]
        launch_barrier_flag = int(os.getenv("FLAGS_LAUNCH_BARRIER", "1"))
        if launch_barrier and launch_barrier_flag:
            wait_server_ready(self.role_maker._get_pserver_endpoints())
            if (
                self.is_heter_ps_mode
                and self.role_maker._get_next_trainers() != []
            ):
                wait_server_ready(self.role_maker._get_next_trainers())
            if self.is_heter_ps_mode:
                previous_trainers = []
                if self.role_maker._get_previous_trainers() != []:
                    previous_trainers = self.role_maker._get_previous_trainers()
                next_trainers = []
                if self.role_maker._get_next_trainers() != []:
                    next_trainers = self.role_maker._get_next_trainers()
                self._heter_client = HeterClient(
                    next_trainers, previous_trainers, self.role_maker._role_id()
                )  # --> HeterClient::GetInstance
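    # A minimal usage sketch for the worker-side init path above; the strategy
    # and role-maker setup are assumptions about a typical parameter-server
    # training script, not part of this runtime:
    #
    #     import paddle.distributed.fleet as fleet
    #     fleet.init()                              # role maker built from env vars
    #     opt = fleet.distributed_optimizer(opt, strategy)
    #     opt.minimize(loss)
    #     if fleet.is_worker():
    #         fleet.init_worker()                   # eventually reaches _init_worker above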

    def _init_coordinator(self, scopes=None):
        if self._coordinator is None:
            self._coordinator = Coordinator(self.string_hosts)

        print(">>> curr node ip: {}".format(self.coordinator_hosts[0]))
        print(">>> all trainer endpoints: {}".format(self.trainer_endpoints))
        self._coordinator.start_coordinator(
            self.coordinator_hosts[0], self.trainer_endpoints
        )

    def _make_fl_strategy(self):
        if self._coordinator is None:
            raise ValueError("Coordinator py object is null!")
        else:
            self._coordinator.make_fl_strategy()

    def _init_server(self, dirname=None, var_names=None, **kwargs):
        server_desc = self.ps_desc_builder.build_server_desc()
        trainers = get_trainers(self.role_maker)
        if self.is_heter_ps_mode:
            trainers += len(self.role_maker._get_heter_worker_endpoints())

        if self.debug:
            print("server_desc: \n{}".format(server_desc))

        self._server = core.DistFleetWrapper()
        self._server.init_server(
            server_desc,
            self.string_hosts,
            self.role_id,
            trainers,
            self._server_sub_program,
        )

        dist_varnames = get_sparse_tablenames(self.origin_main_programs, True)
        sparse_varnames = get_sparse_tablenames(
            self.origin_main_programs, False
        )

        distributed_varnames = dist_varnames + sparse_varnames

        if var_names is None:
            load_varnames = distributed_varnames
        else:
            for var_name in var_names:
                if var_name not in distributed_varnames:
                    raise ValueError(
                        "fleet.init server can only load sparse variables in {}".format(
                            distributed_varnames
                        )
                    )
            load_varnames = var_names

        if dirname is None or not load_varnames:
            return

        sparse_table_maps = self.ps_desc_builder.sparse_table_maps

        dirname = os.path.normpath(dirname)
        pserver_id = self.role_maker._role_id()

        for var_name in load_varnames:
            table_id = sparse_table_maps[var_name]
            self._server.load_sparse(dirname, "0", table_id)

    def _run_server(self):
        ep = get_ps_endpoint(self.role_maker)
        host, port = ep.split(":")
        self._server.run_server(host, int(port))
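    # A sketch of the server-side call order assumed for the two methods above
    # (a typical fleet script; the warm-start dirname is purely illustrative):
    #
    #     if fleet.is_server():
    #         fleet.init_server()                 # -> _init_server
    #         # fleet.init_server("./warm_start") # optionally preloads sparse tables
    #         fleet.run_server()                  # -> _run_server, serves until stopped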

    def _stop_worker(self):
        if self.context['ps_mode'] == DistributedMode.GEO:
            self._communicator.stop()
        self._worker.stop_worker()
        if self.is_heter_ps_mode:
            assert (
                self._heter_client is not None
            ), "heter client should not be None in heterps mode"
            self._heter_client.stop()

    @staticmethod
    def __exclude_vars(exclude_var_names=[]):
        def is_valid(var):
            if var.name in exclude_var_names:
                return False

            from .utils.public import _get_varname_parts

            origin_varname, _, _ = _get_varname_parts(var.name)
            if origin_varname.endswith("@GRAD"):
                return False

            if origin_varname.startswith("learning_rate_"):
                return False

            if (
                var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH
                or var.desc.type() == core.VarDesc.VarType.FETCH_LIST
                or var.desc.type() == core.VarDesc.VarType.READER
            ):
                return False
            return var.persistable

        return is_valid

    def _get_inference_model_path(self, dirname):
        if dirname.startswith("afs:") or dirname.startswith("hdfs:"):
            model_path = "./dnn_plugin"
        else:
            model_path = os.path.join(dirname, "dnn_plugin")
        return model_path
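    # Illustrative mappings of the branch above (the concrete paths are made up):
    #   "hdfs:/user/dnn/model"  -> "./dnn_plugin"                  (remote fs, staged locally)
    #   "afs:/user/dnn/model"   -> "./dnn_plugin"
    #   "/home/work/output"     -> "/home/work/output/dnn_plugin"  (local directory)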

    def _ps_save_dense_params(
        self, executor, dirname, scope, program, var_names=None
    ):
        dense_map = get_the_one_recv_context(
            self.context, split_dense_table=self.is_heter_ps_mode
        )
        send_ctx = get_the_one_send_context(
            self.context,
            split_dense_table=self.is_heter_ps_mode,
            ep_list=self.endpoints,
        )
        if program is None or len(self.origin_main_programs) == 1:
            program = self.origin_main_programs[0]
        dense_var_names = self._pull_dense(program, scope, send_ctx, dense_map)
        save_var_names = dense_var_names if var_names is None else var_names
        vars = [program.global_block().var(i) for i in save_var_names]
        import paddle

        with paddle.static.scope_guard(scope):
            paddle.static.save_vars(
                executor, "./", program, vars=vars, filename=dirname
            )

    def _save_sparse_params(
        self, executor, dirname, context, main_program, mode
    ):
        distributed_varnames = get_sparse_tablenames(
            self.origin_main_programs, True
        )
        values = []
        model_path = self._get_inference_model_path(dirname)
        for id, names in context.items():
            if names[0] not in distributed_varnames:
                # only save sparse param to local
                try:
                    self._worker.recv_and_save_model(id, model_path)
                except:
                    pass
            # save sparse & distributed param on server
            self._worker.save_one_model(id, dirname, mode)
            values.extend(names)
        # self._worker.save_all_model(dirname, mode)
        return values

    def _save_distributed_persistables(
        self, executor, dirname, main_program=None, mode=0, **kwargs
    ):
        """
        This function collects all variables with `persistable==True` from the
        given `main_program` and saves them to the folder `dirname` or the file
        `filename`.

        The `dirname` is used to specify the folder where the persistable
        variables are going to be saved. If you would like to save variables in
        separate files, set `filename` to None; if you would like to save all
        variables in a single file, use `filename` to specify the file name.
        """

        if isinstance(executor, ParallelExecutor):
            raise TypeError(
                "in fleet.save() function, executor must be as Executor type, ParallelExecutor is not allowed"
            )

        if not isinstance(executor, Executor):
            raise TypeError(
                "in fleet.save() function, executor must be as Executor type"
            )

        if main_program is None:
            main_program = self.context['origin_main_program']

        if isinstance(main_program, CompiledProgram):
            raise TypeError(
                "in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed"
            )

        self._worker.save_all_model(dirname, mode)
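    # A hedged usage sketch: the executor is only type-checked here, the actual
    # save is delegated to the worker via save_all_model. A typical trainer-side
    # call (the executor and path are placeholders) is assumed to look like:
    #
    #     exe = paddle.static.Executor(paddle.CPUPlace())
    #     fleet.save_persistables(exe, "./output/persistables")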

    def _ps_inference_save_inference_model(
        self,
        executor,
        dirname,
        feeded_var_names,
        target_vars,
        main_program=None,
        export_for_deployment=True,
        mode=0,
    ):
        """
        Prune the given `main_program` to build a new program specifically for inference,
        and then save it and all related parameters to the given `dirname` by the `executor`.
        """

        if isinstance(executor, ParallelExecutor):
            raise TypeError(
                "in fleet.save() function, executor must be as Executor type, ParallelExecutor is not allowed"
            )

        if not isinstance(executor, Executor):
            raise TypeError(
                "in fleet.save() function, executor must be as Executor type"
            )

        import paddle

        program = (
            self.origin_main_programs[0]
            if main_program is None
            else main_program
        )
        _, _, idx = get_program_by_id(self.context, id(program))
        scope = self.scopes[idx]
        print("save inference model scope idx:", idx)

        if isinstance(program, CompiledProgram):
            raise TypeError(
                "in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed"
            )

        feed_vars = [
            program.global_block().var(name) for name in feeded_var_names
        ]

        infer_program = paddle.static.normalize_program(
            program, feed_vars, target_vars
        )

        infer_program._copy_dist_param_info_from(program)

        model_path = self._get_inference_model_path(dirname)
        model_basename = "__model__"
        model_basename = os.path.join(model_path, model_basename)
        paddle.save(infer_program, model_basename)

        sparses = get_the_one_recv_context(
            self.context,
            is_dense=False,
            split_dense_table=self.is_heter_ps_mode,
        )
        sparse_names = self._save_sparse_params(
            executor, dirname, sparses, main_program, mode
        )

        dense_map = get_the_one_recv_context(
            self.context, split_dense_table=self.is_heter_ps_mode
        )
        send_ctx = get_the_one_send_context(
            self.context,
            split_dense_table=self.is_heter_ps_mode,
            ep_list=self.endpoints,
        )
        self._pull_dense(program, scope, send_ctx, dense_map)

        generate_vars = self.context[
            "user_defined_strategy"
        ].trainer_desc_configs["stat_var_names"]
        generate_vars = [var for var in generate_vars]
        remaining_vars = list(
            filter(
                TheOnePSRuntime.__exclude_vars(sparse_names),
                infer_program.list_vars(),
            )
        )

        for var in remaining_vars:
            tensor = var.get_value(scope)
            paddle.save(
                tensor,
                os.path.join(model_path, var.name),
                use_binary_format=True,
            )
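    # A hedged usage sketch for the method above; fleet.save_inference_model is
    # assumed to be the public entry point (it funnels through
    # _save_inference_model below on the first worker), and the executor, path
    # and feed/fetch names are placeholders:
    #
    #     exe = paddle.static.Executor(paddle.CPUPlace())
    #     fleet.save_inference_model(exe, "./output/inference", ["x"], [prediction])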

    def _save_cache_model(self, dirname, **kwargs):
        mode = kwargs.get("mode", 1)
        table_id = kwargs.get("table_id", 0)
        self._worker.client_flush()
        fleet.util.barrier()
        cache_threshold = 0.0

        if self.role_maker._is_first_worker():
            cache_threshold = self._worker.get_cache_threshold(table_id)
        # check that the cache threshold is valid
        fleet.util.barrier()

        if self.role_maker._is_first_worker():
            self._worker.cache_shuffle(table_id, dirname, mode, cache_threshold)

        fleet.util.barrier()

        feasign_num = -1
        if self.role_maker._is_first_worker():
            feasign_num = self._worker.save_cache(table_id, dirname, mode)

        fleet.util.barrier()
        return feasign_num
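    # Sketch of the pattern above: every worker flushes and joins the barriers,
    # but only the first worker computes the threshold, shuffles and saves the
    # cache. The public wrapper and its kwargs are assumptions:
    #
    #     feasign_num = fleet.save_cache_model("./output/cache", mode=1, table_id=0)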

    def _check_save_pre_patch_done(self):
        fleet.util.barrier()
        if self.role_maker._is_first_worker():
            self._worker.check_save_pre_patch_done()
        fleet.util.barrier()

    def _load_sparse_params(self, dirname, context, main_program, mode):
        distributed_varnames = get_sparse_tablenames(
            self.origin_main_programs, True
        )
        values = []
        for id, names in context.items():
            if names[0] not in distributed_varnames:
                # TODO: only load sparse param from local
                warnings.warn("varname is not in distributed_varnames, pass")
            # load sparse & distributed param on server
            self._worker.load_one_table(id, dirname, mode)
            values.extend(names)
        return values

    def _ps_inference_load_inference_model(
        self, dirname, mode=0, main_program=None
    ):
        main_program = (
            self.origin_main_programs[0]
            if main_program is None
            else main_program
        )
        _, _, idx = get_program_by_id(self.context, id(main_program))
        scope = self.scopes[idx]
        print("load inference model scope idx:", idx)

        if isinstance(main_program, CompiledProgram):
            raise TypeError(
                "in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed"
            )

        sparses = get_the_one_recv_context(
            self.context,
            is_dense=False,
            split_dense_table=self.is_heter_ps_mode,
        )

        sparse_varnames = self._load_sparse_params(
            dirname, sparses, main_program, mode
        )

        dense_map = get_the_one_recv_context(
            self.context, split_dense_table=self.is_heter_ps_mode
        )
        send_ctx = get_the_one_send_context(
            self.context,
            split_dense_table=self.is_heter_ps_mode,
            ep_list=self.endpoints,
        )

        recv_dense_varnames = []
        for _, names in dense_map.items():
            recv_dense_varnames.extend(names)

        loaded_varnames = sparse_varnames

        remaining_vars = list(
            filter(
                TheOnePSRuntime.__exclude_vars(loaded_varnames),
                main_program.list_vars(),
            )
        )

        model_path = self._get_inference_model_path(dirname)
        import paddle

        for var in remaining_vars:
            if var.name not in recv_dense_varnames:
                continue
            tensor = paddle.load(os.path.join(model_path, var.name))
            var.set_value(tensor, scope)

        self._init_params(main_program, scope, send_ctx, dense_map)

    def _save_one_table(self, table_id, path, mode):
        fleet.util.barrier()
        if self.role_maker._is_first_worker():
            self._worker.save_one_model(table_id, path, mode)
        fleet.util.barrier()

    def _save_dense_params(self, *args, **kwargs):
        fleet.util.barrier()
        if self.role_maker._is_first_worker():
            self._ps_save_dense_params(*args, **kwargs)
        fleet.util.barrier()

    def _save_persistables(self, *args, **kwargs):
        fleet.util.barrier()
        if self.role_maker._is_first_worker():
            self._save_distributed_persistables(*args, **kwargs)
        fleet.util.barrier()

    def _save_inference_model(self, *args, **kwargs):
        fleet.util.barrier()
        if self.role_maker._is_first_worker():
            self._ps_inference_save_inference_model(*args, **kwargs)
        fleet.util.barrier()

    def _load_one_table(self, table_id, path, mode):
        fleet.util.barrier()
        if self.role_maker._is_first_worker():
            self._worker.load_one_table(table_id, path, mode)
        fleet.util.barrier()

    def _load_persistables(self, path, mode):
        fleet.util.barrier()
        if self.role_maker._is_first_worker():
            self._worker.load_model(path, mode)
        fleet.util.barrier()

    def _load_inference_model(self, path, mode):
        fleet.util.barrier()
        if self.role_maker._is_first_worker():
            self._ps_inference_load_inference_model(path, mode)
        fleet.util.barrier()

    def _shrink(self, threshold=None):
        if threshold is not None:
            warnings.warn(
                "The param threshold is not used in MemorySparseTable, if you need to shrink, please set the config of accessor"
            )
        else:
            threshold = 0

        fleet.util.barrier()
        if self.role_maker._is_first_worker():
            sparses = get_the_one_recv_context(
                self.context,
                is_dense=False,
                split_dense_table=self.role_maker._is_heter_parameter_server_mode,
            )

            for id, names in sparses.items():
                self._worker.shrink_sparse_table(id, threshold)
        fleet.util.barrier()
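    # _shrink asks the server to evict features from every sparse table, with
    # the eviction policy taken from the accessor config as warned above. The
    # public wrapper is assumed to be fleet.shrink():
    #
    #     fleet.shrink()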