one_hot_encoding.py 3.5 KB
Newer Older
H
He, Kai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provide one hot encoding tools, implemented by OT (Oblivious Transfer)-based
PSI (Private Set Intersection) algorithm.
"""
from multiprocessing.connection import Client, Listener
import mpc_data_utils as mdu

__all__ = ['one_hot_encoding_map', ]


def one_hot_encoding_map(input_set, host_addr, is_client=True):
    """
    A protocol to get agreement between 2 parties for encoding one
    discrete feature to one hot vector via OT-PSI.

    Args:
        input_set (set:str): The set of possible feature value owned by this
        party. Element of set is str, convert before pass in.

        host_addr (str): The info of host_addr,e.g., ip:port

        is_receiver (bool): True if this party plays as socket client
        otherwise, plays as socket server

    Return Val: dict, int.
H
He,Kai 已提交
40 41
        dict key: feature values in input_set,
        dict value: corresponding idx in one hot vector.
H
He, Kai 已提交
42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59

        int: length of one hot vector for this feature.

    Examples:
        .. code-block:: python

            import paddle_fl.mpc.data_utils
            import sys

            is_client = sys.argv[1] == "1"

            a = set([str(x) for x in range(7)])
            b = set([str(x) for x in range(5, 10)])

            addr = "127.0.0.1:33784"

            ins = a if is_client else b

H
He,Kai 已提交
60
            x, y = paddle_fl.mpc.data_utils.one_hot_encoding_map(ins, addr, is_client)
H
He, Kai 已提交
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
            # y = 10
            # x['5'] = 0, x['6'] = 1
            # for those feature val owned only by one party, dict val shall
            not be conflicting.
            print(x, y)
    """

    ip = host_addr.split(":")[0]
    port = int(host_addr.split(":")[1])

    if is_client:
        intersection = input_set
        intersection = mdu.recv_psi(ip, port, intersection)
        intersection = sorted(list(intersection))
        # Only the receiver can obtain the result.
        # Send result to other parties.
    else:
        ret_code = mdu.send_psi(port, input_set)
        if ret_code != 0:
            raise RuntimeError("Errors occurred in PSI send lib, "
                               "error code = {}".format(ret_code))

    if not is_client:
        server = Listener((ip, port))
    conn = Client((ip, port)) if is_client else server.accept()


    if is_client:
        conn.send(intersection)

        diff_size_local = len(input_set) - len(intersection)
        conn.send(diff_size_local)
        diff_size_remote = conn.recv()

    else:
        intersection = conn.recv()

        diff_size_local = len(input_set) - len(intersection)

        diff_size_remote = conn.recv()
        conn.send(diff_size_local)

    conn.close()
    if not is_client:
        server.close()

    ret = dict()

    cnt = 0

    for x in intersection:
        ret[x] = cnt
        cnt += 1
    if is_client:
        cnt += diff_size_remote
    for x in [x for x in input_set if x not in intersection]:
        ret[x] = cnt
        cnt += 1

    return ret, len(intersection) + diff_size_local + diff_size_remote