test_fleet_pyramid_hash.py 2.6 KB
Newer Older
C
Chengmo 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
16

17
import paddle
C
Chengmo 已提交
18 19
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
W
wangzhen38 已提交
20 21
from paddle.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
22 23
    StrategyFactory,
)
C
Chengmo 已提交
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46


class TestPyramidHashOpApi(unittest.TestCase):
    def test_dist_geo_server_transpiler(self):
        num_voc = 128
        embed_dim = 64
        x_shape, x_lod = [16, 10], [[3, 5, 2, 6]]
        x = fluid.data(name='x', shape=x_shape, dtype='int32', lod_level=1)
        hash_embd = fluid.contrib.layers.search_pyramid_hash(
            input=x,
            num_emb=embed_dim,
            space_len=num_voc * embed_dim,
            pyramid_layer=4,
            rand_len=16,
            drop_out_percent=0.5,
            is_training=True,
            use_filter=False,
            white_list_len=6400,
            black_list_len=2800,
            seed=3,
            lr=0.002,
            param_attr=fluid.ParamAttr(
                name="PyramidHash_emb_0",
47 48
                learning_rate=0,
            ),
C
Chengmo 已提交
49 50
            param_attr_wl=fluid.ParamAttr(
                name="Filter",
51 52
                learning_rate=0,
            ),
C
Chengmo 已提交
53 54
            param_attr_bl=None,
            distribute_update_vars=["PyramidHash_emb_0"],
55 56
            name=None,
        )
C
Chengmo 已提交
57

58
        cost = paddle.sum(hash_embd)
C
Chengmo 已提交
59 60 61 62 63

        role = role_maker.UserDefinedRoleMaker(
            current_id=0,
            role=role_maker.Role.SERVER,
            worker_num=2,
64 65
            server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"],
        )
C
Chengmo 已提交
66 67 68

        fleet.init(role)

69
        strategy = StrategyFactory.create_geo_strategy(5)
C
Chengmo 已提交
70 71 72 73 74 75 76 77 78 79
        optimizer = fluid.optimizer.SGD(0.1)
        optimizer = fleet.distributed_optimizer(optimizer, strategy)
        optimizer.minimize(cost)

        pserver_startup_program = fleet.startup_program
        pserver_mian_program = fleet.main_program


if __name__ == "__main__":
    unittest.main()