node.py 31.6 KB
Newer Older
D
dongdaxiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
13
"""Defination of Server and Worker."""
D
dongdaxiang 已提交
14

15
from . import ps_pb2 as pslib
16

17 18
# NOTE: reduce was moved from the builtins into functools in Python 3
from functools import reduce
D
dongdaxiang 已提交
19 20


21
class Server:
    """
    Base class for server descriptions.

    It carries no behaviour of its own; subclasses provide the
    implementation.
    """

    def __init__(self):
        """Create the base object; there is no state to initialise."""


31
class Worker:
    """
    Base class for worker descriptions.

    It carries no behaviour of its own; subclasses provide the
    implementation.
    """

    def __init__(self):
        """Create the base object; there is no state to initialise."""


class DownpourServer(Server):
    """
43 44 45 46 47
    DownpourServer class is used to generate server program_desc
    Args:
        server: it is pslib.ServerParameter()
    Examples:
        server = DownpourServer()
D
dongdaxiang 已提交
48 49 50
    """

    def __init__(self):
        """
        Create an empty server description and fill in the default
        service settings (class names, start port and thread number).
        """
        server_desc = pslib.ServerParameter()
        service = server_desc.downpour_server_param.service_param
        service.server_class = "DownpourBrpcPsServer"
        service.client_class = "DownpourBrpcPsClient"
        service.service_class = "DownpourPsService"
        service.start_server_port = 0
        service.server_thread_num = 12
        self._server = server_desc
D
dongdaxiang 已提交
63

64
    def add_sparse_table(self, table_id, strategy):
D
dongdaxiang 已提交
65 66 67
        """
        Args:
            table_id(int): id of sparse params table
68
            strategy(dict): the config dict.
D
dongdaxiang 已提交
69
        Returns:
70
            return None
D
dongdaxiang 已提交
71
        """
72

73 74 75 76 77
        for table in self._server.downpour_server_param.downpour_table_param:
            if table.table_id == table_id:
                if table.type == pslib.PS_SPARSE_TABLE:
                    return
                else:
78 79 80 81
                    raise ValueError(
                        "expect table %s type=%s, but actual type=%s"
                        % (table_id, pslib.PS_SPARSE_TABLE, table.type)
                    )
82 83
        if strategy is None:
            strategy = dict()
D
dongdaxiang 已提交
84
        table = self._server.downpour_server_param.downpour_table_param.add()
D
dongdaxiang 已提交
85 86
        table.table_id = table_id
        table.type = pslib.PS_SPARSE_TABLE
87

88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
        support_sparse_key_list = [
            'sparse_table_class',
            'sparse_compress_in_save',
            'sparse_shard_num',
            'sparse_accessor_class',
            'sparse_learning_rate',
            'sparse_initial_g2sum',
            'sparse_initial_range',
            'sparse_weight_bounds',
            'sparse_embedx_dim',
            'sparse_embedx_threshold',
            'sparse_nonclk_coeff',
            'sparse_click_coeff',
            'sparse_base_threshold',
            'sparse_delta_threshold',
            'sparse_delta_keep_days',
            'sparse_delete_after_unseen_days',
            'sparse_show_click_decay_rate',
            'sparse_delete_threshold',
            'sparse_converter',
            'sparse_deconverter',
            'sparse_enable_cache',
            'sparse_cache_rate',
            'sparse_cache_file_num',
            'sparse_beta1_decay_rate',
            'sparse_beta2_decay_rate',
            'sparse_ada_epsilon',
            'sparse_optimizer',
            'sparse_ssd_unseenday_threshold',
            'embed_sparse_optimizer',
            'embed_sparse_learning_rate',
            'embed_sparse_weight_bounds',
            'embed_sparse_initial_range',
            'embed_sparse_initial_g2sum',
            'embed_sparse_beta1_decay_rate',
            'embed_sparse_beta2_decay_rate',
            'embedx_sparse_optimizer',
            'embedx_sparse_learning_rate',
            'embedx_sparse_weight_bounds',
            'embedx_sparse_initial_range',
            'embedx_sparse_initial_g2sum',
            'embedx_sparse_beta1_decay_rate',
            'embedx_sparse_beta2_decay_rate',
        ]
132 133 134 135 136

        for key in strategy:
            if key not in support_sparse_key_list:
                raise ValueError("strategy key '%s' not support" % (key))

137
        support_table_calss = ['DownpourSparseTable', 'DownpourSparseSSDTable']
138 139 140 141
        if strategy.get('sparse_table_class') is not None:
            table_class = strategy.get('sparse_table_class')
            if table_class not in support_table_calss:
                raise ValueError(
142
                    "support sparse_table_class: [ 'DownpourSparseTable', 'DownpourSparseSSDTable'], \
143 144 145
                        but actual %s"
                    % (table_class)
                )
146 147 148 149 150
        else:
            table_class = 'DownpourSparseTable'

        table.table_class = table_class

151 152 153 154
        if (
            table_class == 'DownpourSparseTable'
            or table_class == 'DownpourSparseSSDTable'
        ):
155
            table.enable_sparse_table_cache = strategy.get(
156 157
                'sparse_enable_cache', True
            )
158
            table.sparse_table_cache_rate = strategy.get(
159 160
                'sparse_cache_rate', 0.00055
            )
161
            table.sparse_table_cache_file_num = strategy.get(
162 163 164 165 166
                'sparse_cache_file_num', 16
            )
            table.compress_in_save = strategy.get(
                'sparse_compress_in_save', True
            )
167
            table.shard_num = strategy.get('sparse_shard_num', 1000)
168 169 170
            # DownpourFeatureValueAccessor: for ctr task, has cvm, embedding and sgd info
            # DownpourCtrAccessor         : for ctr task, has cvm, slot, embedding and sgd info
            # DownpourSparseValueAccessor : for general task, has embedding and sgd info
171
            # DownpourCtrDoubleAccessor   : for ctr task, which show clk are in double
X
xujiaqi01 已提交
172
            # DownpourUnitAccessor        : for ctr task, has cvm, slot, embedding and sgd info
173 174

            support_accessor_class = [
175 176 177 178 179 180 181
                'DownpourFeatureValueAccessor',
                'DownpourCtrAccessor',
                'DownpourCtrDymfAccessor',
                'DownpourSparseValueAccessor',
                'DownpourCtrDoubleAccessor',
                'DownpourUnitAccessor',
                'DownpourDoubleUnitAccessor',
182 183 184 185 186
            ]
            if strategy.get('sparse_accessor_class') is not None:
                accessor_class = strategy.get('sparse_accessor_class')
                if accessor_class not in support_accessor_class:
                    raise ValueError(
Y
yaoxuefeng 已提交
187
                        "support sparse_accessor_class: ['DownpourFeatureValueAccessor', 'DownpourCtrAccessor', 'DownpourCtrDymfAccessor', \
188
                        'DownpourSparseValueAccessor', 'DownpourCtrDoubleAccessor'], \
189 190 191
                            but actual %s"
                        % (accessor_class)
                    )
192 193 194 195 196
            else:
                accessor_class = 'DownpourCtrAccessor'

            table.accessor.accessor_class = accessor_class

197 198 199 200 201 202
            if (
                accessor_class == 'DownpourFeatureValueAccessor'
                or accessor_class == 'DownpourCtrAccessor'
                or accessor_class == 'DownpourCtrDymfAccessor'
                or accessor_class == 'DownpourCtrDoubleAccessor'
            ):
203
                table.accessor.sparse_sgd_param.learning_rate = strategy.get(
204 205
                    'sparse_learning_rate', 0.05
                )
206
                table.accessor.sparse_sgd_param.initial_g2sum = strategy.get(
207 208
                    'sparse_initial_g2sum', 3
                )
209
                table.accessor.sparse_sgd_param.initial_range = strategy.get(
210 211
                    'sparse_initial_range', 1e-4
                )
212 213
                if strategy.get('sparse_weight_bounds') is None:
                    table.accessor.sparse_sgd_param.weight_bounds.extend(
214 215
                        [-10, 10]
                    )
216 217
                else:
                    table.accessor.sparse_sgd_param.weight_bounds.extend(
218 219
                        strategy.get('sparse_weight_bounds')
                    )
220 221
                table.accessor.embedx_dim = strategy.get('sparse_embedx_dim', 8)
                table.accessor.embedx_threshold = strategy.get(
222 223
                    'sparse_embedx_threshold', 10
                )
224
                table.accessor.fea_dim = int(table.accessor.embedx_dim) + 3
225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
                table.accessor.downpour_accessor_param.nonclk_coeff = (
                    strategy.get('sparse_nonclk_coeff', 0.1)
                )
                table.accessor.downpour_accessor_param.click_coeff = (
                    strategy.get('sparse_click_coeff', 1)
                )
                table.accessor.downpour_accessor_param.base_threshold = (
                    strategy.get('sparse_base_threshold', 1.5)
                )
                table.accessor.downpour_accessor_param.delta_threshold = (
                    strategy.get('sparse_delta_threshold', 0.25)
                )
                table.accessor.downpour_accessor_param.delta_keep_days = (
                    strategy.get('sparse_delta_keep_days', 16)
                )
240
                table.accessor.downpour_accessor_param.delete_after_unseen_days = strategy.get(
241 242
                    'sparse_delete_after_unseen_days', 30
                )
243
                table.accessor.downpour_accessor_param.ssd_unseenday_threshold = strategy.get(
244 245 246 247 248 249 250 251
                    'sparse_ssd_unseenday_threshold', 1
                )
                table.accessor.downpour_accessor_param.show_click_decay_rate = (
                    strategy.get('sparse_show_click_decay_rate', 0.98)
                )
                table.accessor.downpour_accessor_param.delete_threshold = (
                    strategy.get('sparse_delete_threshold', 0.8)
                )
252 253
                converter = strategy.get(
                    'sparse_converter',
254 255
                    "(scripts/xbox_compressor_mf.py | bin/xbox_pb_converter)",
                )
256
                deconverter = strategy.get(
257
                    'sparse_deconverter',
258
                    "(bin/xbox_pb_deconverter | scripts/xbox_decompressor_mf.awk)",
259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275
                )

                table1 = table.accessor.table_accessor_save_param.add()
                table1.param = 1
                table1.converter = converter
                table1.deconverter = deconverter

                table2 = table.accessor.table_accessor_save_param.add()
                table2.param = 2
                table2.converter = converter
                table2.deconverter = deconverter
            elif accessor_class == 'DownpourSparseValueAccessor':
                optimizer_name = strategy.get("sparse_optimizer", "adam")
                table.accessor.sparse_commonsgd_param.name = optimizer_name
                table.accessor.embedx_dim = strategy.get('sparse_embedx_dim', 8)
                table.accessor.fea_dim = int(table.accessor.embedx_dim)
                if optimizer_name == "naive":
276 277 278 279 280 281
                    table.accessor.sparse_commonsgd_param.naive.learning_rate = strategy.get(
                        'sparse_learning_rate', 0.05
                    )
                    table.accessor.sparse_commonsgd_param.naive.initial_range = strategy.get(
                        'sparse_initial_range', 1e-4
                    )
282 283
                    if strategy.get('sparse_weight_bounds') is None:
                        table.accessor.sparse_commonsgd_param.naive.weight_bounds.extend(
284 285
                            [-10, 10]
                        )
286 287
                    else:
                        table.accessor.sparse_commonsgd_param.naive.weight_bounds.extend(
288 289
                            strategy.get('sparse_weight_bounds')
                        )
290
                elif optimizer_name == "adagrad":
291 292 293 294 295 296
                    table.accessor.sparse_commonsgd_param.adagrad.learning_rate = strategy.get(
                        'sparse_learning_rate', 0.05
                    )
                    table.accessor.sparse_commonsgd_param.adagrad.initial_range = strategy.get(
                        'sparse_initial_range', 1e-4
                    )
297
                    table.accessor.sparse_commonsgd_param.adagrad.initial_g2sum = strategy.get(
298 299
                        'sparse_initial_g2sum', 3
                    )
300 301
                    if strategy.get('sparse_weight_bounds') is None:
                        table.accessor.sparse_commonsgd_param.adagrad.weight_bounds.extend(
302 303
                            [-10, 10]
                        )
304 305
                    else:
                        table.accessor.sparse_commonsgd_param.adagrad.weight_bounds.extend(
306 307
                            strategy.get('sparse_weight_bounds')
                        )
308
                elif optimizer_name == "adam":
309
                    table.accessor.sparse_commonsgd_param.adam.learning_rate = (
310
                        strategy.get('sparse_learning_rate', 0.001)
311 312
                    )
                    table.accessor.sparse_commonsgd_param.adam.initial_range = (
313
                        strategy.get('sparse_initial_range', 1e-4)
314
                    )
315
                    table.accessor.sparse_commonsgd_param.adam.beta1_decay_rate = strategy.get(
316 317
                        'sparse_beta1_decay_rate', 0.9
                    )
318
                    table.accessor.sparse_commonsgd_param.adam.beta2_decay_rate = strategy.get(
319 320 321 322 323
                        'sparse_beta2_decay_rate', 0.999
                    )
                    table.accessor.sparse_commonsgd_param.adam.ada_epsilon = (
                        strategy.get('sparse_ada_epsilon', 1e-8)
                    )
324 325
                    if strategy.get('sparse_weight_bounds') is None:
                        table.accessor.sparse_commonsgd_param.adam.weight_bounds.extend(
326 327
                            [-10, 10]
                        )
328 329
                    else:
                        table.accessor.sparse_commonsgd_param.adam.weight_bounds.extend(
330 331
                            strategy.get('sparse_weight_bounds')
                        )
332 333
                converter = strategy.get(
                    'sparse_converter',
334 335
                    "(scripts/xbox_compressor_mf.py | bin/xbox_pb_converter)",
                )
336
                deconverter = strategy.get(
337
                    'sparse_deconverter',
338
                    "(bin/xbox_pb_deconverter | scripts/xbox_decompressor_mf.awk)",
339 340
                )

341 342
                table1 = table.accessor.table_accessor_save_param.add()
                table1.param = 1
343 344 345
                table1.converter = converter
                table1.deconverter = deconverter

346 347
                table2 = table.accessor.table_accessor_save_param.add()
                table2.param = 2
348 349
                table2.converter = converter
                table2.deconverter = deconverter
350 351 352 353
            elif (
                accessor_class == 'DownpourUnitAccessor'
                or accessor_class == 'DownpourDoubleUnitAccessor'
            ):
X
xujiaqi01 已提交
354
                self.add_sparse_table_common_config(table, strategy)
355 356 357 358 359 360
                self.add_sparse_optimizer(
                    table.accessor.embed_sgd_param, strategy, "embed_"
                )
                self.add_sparse_optimizer(
                    table.accessor.embedx_sgd_param, strategy, "embedx_"
                )
361

362 363 364
    def add_dense_table(
        self, table_id, param_var, grad_var, strategy, sparse_table_names
    ):
D
dongdaxiang 已提交
365 366 367
        """
        Args:
            table_id(int): id of sparse params table
368 369 370 371
            param_var(list): param vars
            grad_var(list): param grad vars
            strategy(dict): the dense config dict
            sparse_table_names(list): sparse table names
D
dongdaxiang 已提交
372
        Returns:
373
            return None
D
dongdaxiang 已提交
374
        """
375
        fea_dim = 0
376 377
        dense_param_vars = []
        for p in param_var:
378
            if p.name not in sparse_table_names:
379 380 381
                dense_param_vars.append(p)

        for param in dense_param_vars:
382 383 384 385 386 387 388 389
            fea_dim += reduce(lambda x, y: x * y, param.shape, 1)

        for table in self._server.downpour_server_param.downpour_table_param:
            if table.table_id == table_id:
                if table.type == pslib.PS_DENSE_TABLE:
                    table.accessor.fea_dim = fea_dim
                    return
                else:
390 391 392 393
                    raise ValueError(
                        "expect table %s type=%s, but actual type=%s"
                        % (table_id, pslib.PS_DENSE_TABLE, table.type)
                    )
394 395 396

        if strategy is None:
            strategy = dict()
T
tangwei12 已提交
397
        table = self._server.downpour_server_param.downpour_table_param.add()
D
dongdaxiang 已提交
398
        table.table_id = table_id
399 400 401 402 403 404 405 406 407 408 409 410
        support_dense_key_list = [
            'dense_table_class',
            'dense_compress_in_save',
            'dense_accessor_class',
            'dense_optimizer',
            'dense_learning_rate',
            'dense_avg_decay',
            'dense_ada_decay',
            'dense_ada_epsilon',
            'dense_mom_decay',
            'dense_naive_lr',
        ]
411 412 413 414 415

        for key in strategy:
            if key not in support_dense_key_list:
                raise ValueError("strategy key '%s' not support" % (key))

416 417 418
        table.table_class = strategy.get(
            'dense_table_class', "DownpourDenseTable"
        )
D
dongdaxiang 已提交
419
        table.type = pslib.PS_DENSE_TABLE
420 421
        table.compress_in_save = strategy.get('dense_compress_in_save', True)
        table.accessor.accessor_class = strategy.get(
422 423
            'dense_accessor_class', "DownpourDenseValueAccessor"
        )
424
        table.accessor.dense_sgd_param.name = strategy.get(
425 426
            'dense_optimizer', "adam"
        )
427
        table.accessor.dense_sgd_param.adam.learning_rate = strategy.get(
428 429
            'dense_learning_rate', 5e-06
        )
430
        table.accessor.dense_sgd_param.adam.avg_decay_rate = strategy.get(
431 432
            'dense_avg_decay', 0.999993
        )
433
        table.accessor.dense_sgd_param.adam.ada_decay_rate = strategy.get(
434 435
            'dense_ada_decay', 0.9999
        )
436
        table.accessor.dense_sgd_param.adam.ada_epsilon = strategy.get(
437 438
            'dense_ada_epsilon', 1e-8
        )
439
        table.accessor.dense_sgd_param.adam.mom_decay_rate = strategy.get(
440 441
            'dense_mom_decay', 0.99
        )
442
        table.accessor.dense_sgd_param.naive.learning_rate = strategy.get(
443 444
            'dense_naive_lr', 0.0002
        )
D
dongdaxiang 已提交
445 446
        table.accessor.fea_dim = fea_dim

447 448 449 450 451 452 453 454 455
    def add_data_norm_table(
        self,
        table_id,
        learning_rate,
        param_var,
        grad_var,
        strategy,
        sparse_table_names,
    ):
D
dongdaxiang 已提交
456 457
        """
        Args:
458
            table_id(int): id of datanorm table
459 460 461 462 463
            learning_rate(float): the learning rate used to update parameters
            param_var(list): param vars
            grad_var(list): param grad vars
            strategy(dict): the datanorm config dict
            sparse_table_names(list): sparse table names
D
dongdaxiang 已提交
464
        Returns:
465
            return None
D
dongdaxiang 已提交
466
        """
467
        fea_dim = 0
468 469
        dense_param_vars = []
        for p in param_var:
470
            if p.name not in sparse_table_names:
471 472 473
                dense_param_vars.append(p)

        for param in dense_param_vars:
474 475 476 477 478 479 480 481
            fea_dim += reduce(lambda x, y: x * y, param.shape, 1)

        for table in self._server.downpour_server_param.downpour_table_param:
            if table.table_id == table_id:
                if table.type == pslib.PS_DENSE_TABLE:
                    table.accessor.fea_dim = fea_dim
                    return
                else:
482 483 484 485
                    raise ValueError(
                        "expect table %s type=%s, but actual type=%s"
                        % (table_id, pslib.PS_DENSE_TABLE, table.type)
                    )
486 487 488
        if strategy is None:
            strategy = dict()

489 490 491 492 493 494 495
        support_datanorm_key_list = [
            'datanorm_table_class',
            'datanorm_compress_in_save',
            'datanorm_accessor_class',
            'datanorm_operation',
            'datanorm_decay_rate',
        ]
496 497 498 499 500

        for key in strategy:
            if key not in support_datanorm_key_list:
                raise ValueError("strategy key '%s' not support" % (key))

D
dongdaxiang 已提交
501
        table = self._server.downpour_server_param.downpour_table_param.add()
D
dongdaxiang 已提交
502
        table.table_id = table_id
503 504 505
        table.table_class = strategy.get(
            'datanorm_table_class', 'DownpourDenseTable'
        )
D
dongdaxiang 已提交
506
        table.type = pslib.PS_DENSE_TABLE
507 508
        table.compress_in_save = strategy.get('datanorm_compress_in_save', True)
        table.accessor.accessor_class = strategy.get(
509 510
            'datanorm_accessor_class', 'DownpourDenseValueAccessor'
        )
511
        table.accessor.dense_sgd_param.name = strategy.get(
512 513 514 515 516
            'datanorm_operation', 'summary'
        )
        table.accessor.dense_sgd_param.summary.summary_decay_rate = (
            strategy.get('datanorm_decay_rate', 0.999999)
        )
D
dongdaxiang 已提交
517 518
        table.accessor.fea_dim = fea_dim

X
xujiaqi01 已提交
519
    def add_sparse_optimizer(self, sgd, strategy, prefix):
        """
        Fill one sparse sgd rule from the strategy entries that start with
        the given prefix.

        Args:
            sgd: the sgd param message to fill; its name field and one of
                its naive / adagrad / adam sub-fields are written
            strategy(dict): the config dict
            prefix(str): key prefix, e.g. "embed_" or "embedx_"
        Returns:
            return None
        """

        def option(suffix, default):
            # Look up "<prefix><suffix>" in the strategy.
            return strategy.get(prefix + suffix, default)

        optimizer_name = option("sparse_optimizer", "adagrad")
        sgd.name = optimizer_name

        if optimizer_name == "naive":
            rule = sgd.naive
            rule.learning_rate = option('sparse_learning_rate', 0.05)
            rule.initial_range = option('sparse_initial_range', 1e-4)
            rule.weight_bounds.extend(option('sparse_weight_bounds', [-10, 10]))
        elif optimizer_name in ("adagrad", "std_adagrad"):
            # Both names are written into the adagrad sub-field.
            rule = sgd.adagrad
            rule.learning_rate = option('sparse_learning_rate', 0.05)
            rule.initial_range = option('sparse_initial_range', 1e-4)
            if prefix == "embed_":
                # The "embed_" rule always ends up with initial_range 0,
                # whatever the strategy says.
                rule.initial_range = 0
            rule.initial_g2sum = option('sparse_initial_g2sum', 3)
            rule.weight_bounds.extend(option('sparse_weight_bounds', [-10, 10]))
        elif optimizer_name == "adam":
            rule = sgd.adam
            rule.learning_rate = option('sparse_learning_rate', 0.001)
            rule.initial_range = option('sparse_initial_range', 1e-4)
            rule.beta1_decay_rate = option('sparse_beta1_decay_rate', 0.9)
            rule.beta2_decay_rate = option('sparse_beta2_decay_rate', 0.999)
            rule.ada_epsilon = option('sparse_ada_epsilon', 1e-8)
            rule.weight_bounds.extend(option('sparse_weight_bounds', [-10, 10]))

    def add_sparse_table_common_config(self, table, strategy):
        """
        Fill the accessor settings of a sparse table: embedx dims, the
        downpour accessor params and the two save params.

        Args:
            table: the table message whose accessor is filled
            strategy(dict): the config dict
        Returns:
            return None
        """
        accessor = table.accessor
        accessor.embedx_dim = strategy.get('sparse_embedx_dim', 8)
        accessor.embedx_threshold = strategy.get('sparse_embedx_threshold', 10)
        # fea_dim is derived from the value just stored in embedx_dim.
        accessor.fea_dim = int(accessor.embedx_dim) + 3

        # (field of downpour_accessor_param, strategy key, default value)
        accessor_param_settings = (
            ('nonclk_coeff', 'sparse_nonclk_coeff', 0.1),
            ('click_coeff', 'sparse_click_coeff', 1),
            ('base_threshold', 'sparse_base_threshold', 1.5),
            ('delta_threshold', 'sparse_delta_threshold', 0.25),
            ('delta_keep_days', 'sparse_delta_keep_days', 16),
            (
                'delete_after_unseen_days',
                'sparse_delete_after_unseen_days',
                30,
            ),
            ('show_click_decay_rate', 'sparse_show_click_decay_rate', 0.98),
            ('delete_threshold', 'sparse_delete_threshold', 0.8),
        )
        accessor_param = accessor.downpour_accessor_param
        for field, key, default in accessor_param_settings:
            setattr(accessor_param, field, strategy.get(key, default))

        converter = strategy.get(
            'sparse_converter',
            "(scripts/xbox_compressor_mf.py | bin/xbox_pb_converter)",
        )
        deconverter = strategy.get(
            'sparse_deconverter',
            "(bin/xbox_pb_deconverter | scripts/xbox_decompressor_mf.awk)",
        )

        # Save params 1 and 2 share the same converter pair.
        for param_id in (1, 2):
            save_param = accessor.table_accessor_save_param.add()
            save_param.param = param_id
            save_param.converter = converter
            save_param.deconverter = deconverter

D
dongdaxiang 已提交
627 628 629 630
    def get_desc(self):
        """
        Return downpour server program_desc
        """
D
dongdaxiang 已提交
631
        return self._server
D
dongdaxiang 已提交
632 633 634 635


class DownpourWorker(Worker):
    """
    DownpourWorker class is used to generate worker program_desc
    Args:
        window (int): push params frequency
    Attributes:
        window: the value passed to the constructor.
        _worker: a pslib.DownpourTrainerParameter that the add_* methods
            fill in and get_desc() returns.
    Examples:
        worker = DownpourWorker(1)
    """

    def __init__(self, window):
        self.window = window
        self._worker = pslib.DownpourTrainerParameter()

    def add_sparse_table(
        self, table_id, slot_key_vars, slot_value_vars, slot_value_grads=None
    ):
        """
        Register (or re-register) a sparse table in the worker desc.

        Args:
            table_id(int): id of sparse params table
            slot_key_vars(list): slot key id
            slot_value_vars(list): slot key value after embedding
            slot_value_grads(list): grad of all params, default is None.
                When None, every value var is assumed to have a gradient
                named ``<value name>@GRAD`` and the caller's ordering of
                keys/values is kept as is. Otherwise only value vars whose
                ``@GRAD`` name appears in this list get a gradient entry,
                and those vars (with their keys) are moved to the front.
        Returns:
            return None
        Raises:
            ValueError: a table with the same ``table_id`` already exists
                and one of the new slot key names is not in its slot_key.
        """
        if slot_value_grads is None:
            slot_value_grad_names = [
                var.name + "@GRAD" for var in slot_value_vars
            ]
            # Previously these two names were only bound in the else branch,
            # so the default call path failed with UnboundLocalError below.
            sorted_slot_value_vars = list(slot_value_vars)
            sorted_slot_key_vars = list(slot_key_vars)
        else:
            # Pair each value var with the key var at the same position.
            value_to_key = {}
            for i, key_var in enumerate(slot_key_vars):
                value_to_key[slot_value_vars[i].name] = key_var

            all_grad_names = {var.name for var in slot_value_grads}
            slot_value_grad_names = [
                var.name + "@GRAD"
                for var in slot_value_vars
                if var.name + "@GRAD" in all_grad_names
            ]
            # Value vars that have a gradient come first, the rest follow;
            # relative order inside each group is unchanged.
            sorted_slot_value_vars = [
                var
                for var in slot_value_vars
                if var.name + "@GRAD" in all_grad_names
            ]
            sorted_slot_value_vars += [
                var
                for var in slot_value_vars
                if var.name + "@GRAD" not in all_grad_names
            ]
            sorted_slot_key_vars = [
                value_to_key[var.name] for var in sorted_slot_value_vars
            ]

        key_names = [var.name for var in sorted_slot_key_vars]

        # If the table id is already registered, its keys must cover the new
        # ones; the old entry is then dropped and rebuilt below.
        target_table = None
        for table in self._worker.sparse_table:
            if table.table_id == table_id:
                keys = table.slot_key
                for key_name in key_names:
                    if key_name not in keys:
                        raise ValueError(
                            "sparse table %s slot_key error" % table_id
                        )
                target_table = table
                break

        if target_table is not None:
            self._worker.sparse_table.remove(target_table)

        table = self._worker.sparse_table.add()
        table.table_id = table_id
        table.slot_key.extend(key_names)
        table.slot_value.extend([var.name for var in sorted_slot_value_vars])
        table.slot_gradient.extend(slot_value_grad_names)

    def add_dense_table(
        self,
        table_id,
        learning_rate,
        param_vars,
        grad_vars,
        dense_start_table_id,
        sparse_table_names,
    ):
        """
        Register a dense table in the worker desc.

        Args:
            table_id(int): id of dense params table
            learning_rate(float): the learning rate used to update
                parameters. Accepted for interface compatibility; this
                method does not read it.
            param_vars(list): all dense param. it is a list.
            grad_vars(list): all dense grad parm it is a list.
            dense_start_table_id(int): dense table start index. Accepted
                for interface compatibility; this method does not read it.
            sparse_table_names(list): sparse table names; params with these
                names, and grads named ``<name>@GRAD``, are excluded.
        Returns:
            return None
        Raises:
            ValueError: a table with the same ``table_id`` already exists
                but its param names or grad names differ from the ones
                computed here.
        """
        sparse_param_names = set(sparse_table_names)
        sparse_grad_names = {name + "@GRAD" for name in sparse_param_names}

        # Names are stored in plain lexicographic order.
        dense_param_name = sorted(
            p.name for p in param_vars if p.name not in sparse_param_names
        )
        dense_grad_name = sorted(
            g.name for g in grad_vars if g.name not in sparse_grad_names
        )

        for table in self._worker.dense_table:
            if table.table_id != table_id:
                continue
            # Same id already registered: accept only an identical table.
            if sorted(table.dense_variable_name) != dense_param_name:
                raise ValueError(
                    "dense table %s dense_variable_name error" % table_id
                )
            if sorted(table.dense_gradient_variable_name) != dense_grad_name:
                raise ValueError(
                    "dense table %s dense_gradient_variable_name "
                    "error" % table_id
                )
            return

        table = self._worker.dense_table.add()
        table.table_id = table_id
        table.dense_variable_name.extend(dense_param_name)
        table.dense_gradient_variable_name.extend(dense_grad_name)

    def get_desc(self):
        """
        Return downpour worker program_desc
        """
        return self._worker