alignment.py 3.9 KB
Newer Older
J
jhjiangcs 已提交
1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
J
jingqinghe 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provides data alignment tools, implemented with an OT (Oblivious
Transfer)-based PSI (Private Set Intersection) algorithm.
"""
J
jhjiangcs 已提交
18
from multiprocessing.connection import Client, Listener
J
jingqinghe 已提交
19 20 21
import os
import sys
import mpc_data_utils as mdu
J
jhjiangcs 已提交
22

J
jingqinghe 已提交
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
# Public API: only ``align`` is exported by a star-import of this module.
__all__ = ["align"]


def align(input_set, party_id, endpoints, is_receiver=True):
    """
    Compute the intersection of the id sets held by all data parties.

    The receiver runs ``mdu.recv_psi`` against every other party in turn,
    feeding the running intersection into the next round, and then pushes
    the final result to those parties. Every other party runs
    ``mdu.send_psi`` on its own port and afterwards waits on its own
    endpoint for the receiver to deliver the result.

    :param input_set: set. The id set of input data owned by this party.
    :param party_id: int. The id of this data party; it is compared
    against the integer id prefix of each entry in ``endpoints``.
    :param endpoints: str. Comma-separated info of all data parties,
    e.g., id1:ip1:port1,id2:ip2:port2
    :param is_receiver: bool. True if this data party is the receiver
    role among all parties. Only the receiver obtains the result of
    aligning directly; it then sends it to the other parties.
    :return: set. The intersection of data id.
    :raises RuntimeError: if ``party_id`` is not listed in ``endpoints``,
    or if ``mdu.send_psi`` returns a non-zero code.
    """
    parties = endpoints.split(",")
    own_idx = _find_party_idx(party_id, parties)
    if own_idx < 0:
        raise RuntimeError("Could not find endpoint with id: {}".format(
            party_id))

    if not is_receiver:
        # Sender role: serve the PSI round on our own port, then wait on
        # the same endpoint for the receiver to deliver the final result.
        own_endpoint = parties[own_idx]
        own_port = int(own_endpoint.split(":")[2])
        status = mdu.send_psi(own_port, input_set)
        if status != 0:
            raise RuntimeError("Errors occurred in PSI send lib, "
                               "error code = {}".format(status))
        return _recv_align_result(own_endpoint)

    # Receiver role: every endpoint except our own is a sender.
    peers = parties[:own_idx] + parties[own_idx + 1:]
    intersection = input_set
    for peer in peers:
        fields = peer.split(":")
        peer_ip = fields[1]
        peer_port = int(fields[2])
        # Each round narrows the running intersection with one more party.
        intersection = set(mdu.recv_psi(peer_ip, peer_port, intersection))

    # Only the receiver can obtain the result.
    # Send result to other parties.
    _send_align_result(intersection, peers)
    return intersection


def _find_party_idx(party_id, endpoint_list):
    """
    Locate the endpoint that belongs to the given party.

    Each endpoint is a string whose first ``:``-separated field is the
    party id; that field is converted to ``int`` before comparison.

    :param party_id: party id to look for
    :param endpoint_list: list of endpoints, each formatted "id:ip:port"
    :return: the position of the first endpoint whose id equals
    ``party_id``, or -1 if no endpoint matches
    """
    for position, endpoint in enumerate(endpoint_list):
        endpoint_id = int(endpoint.split(":")[0])
        if party_id == endpoint_id:
            return position
    return -1


def _send_align_result(result, send_list):
    """
    Send align result to other data parties. This is used by the
    receiver when align.

    A connection is opened to each party in turn, the result is sent
    once, and the connection is closed again. The close also happens when
    the send itself fails, so a failed delivery does not leak the socket.

    :param result: set. The align result.
    :param send_list: list. The data parties who receive the result.
    Each party is represented as "id:ip:port".
    :return: None.
    """
    for host in send_list:
        fields = host.split(":")
        ip_addr = fields[1]
        port = int(fields[2])
        client = Client((ip_addr, port))
        try:
            client.send(result)
        finally:
            # Always release the connection, even if send() raised.
            client.close()


def _recv_align_result(host):
    """
    Receive align result from receiver.

    Opens a listener on the ip and port of ``host``, accepts a single
    connection, reads one object from it and returns it. The connection
    and the listener are closed on every exit path, so the port is
    released even if accepting or receiving fails.

    :param host: str. The host who is waiting for align result.
    The host is represented as "id:ip:port".
    :return: set. The received align result.
    """
    fields = host.split(":")
    ip_addr = fields[1]
    port = int(fields[2])
    server = Listener((ip_addr, port))
    try:
        conn = server.accept()
        try:
            # NOTE(security): Connection.recv() unpickles whatever the peer
            # sent, and this listener is created without an authkey, so any
            # host able to reach the port can deliver the payload. Only run
            # this on a trusted network.
            return conn.recv()
        finally:
            conn.close()
    finally:
        server.close()