提交 9a3ac286 编写于 作者: H He, Kai

fix ks bug & add one hot encoding map

上级 c6665795
......@@ -69,5 +69,6 @@ from .version import version
from .layers import mpc_math_op_patch
from . import input
from . import initializer
from . import metrics
mpc_math_op_patch.monkey_patch_mpc_variable()
......@@ -17,7 +17,10 @@ Import data_utils module.
from . import aby3
from . import alignment
from . import one_hot_encoding
from .alignment import *
from .one_hot_encoding import *
__all__ = []
__all__ += alignment.__all__
__all__ += one_hot_encoding.__all__
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provide one hot encoding tools, implemented by OT (Oblivious Transfer)-based
PSI (Private Set Intersection) algorithm.
"""
from multiprocessing.connection import Client, Listener
import mpc_data_utils as mdu
__all__ = ['one_hot_encoding_map', ]
def one_hot_encoding_map(input_set, host_addr, is_client=True):
"""
A protocol to get agreement between 2 parties for encoding one
discrete feature to one hot vector via OT-PSI.
Args:
input_set (set:str): The set of possible feature value owned by this
party. Element of set is str, convert before pass in.
host_addr (str): The info of host_addr,e.g., ip:port
is_receiver (bool): True if this party plays as socket client
otherwise, plays as socket server
Return Val: dict, int.
dict key: feature value, dict value: corresponding idx in one hot vector.
int: length of one hot vector for this feature.
Examples:
.. code-block:: python
import paddle_fl.mpc.data_utils
import sys
is_client = sys.argv[1] == "1"
a = set([str(x) for x in range(7)])
b = set([str(x) for x in range(5, 10)])
addr = "127.0.0.1:33784"
ins = a if is_client else b
x, y = paddle_fl.mpc.data_utilsone_hot_encoding_map(ins, addr, is_client)
# y = 10
# x['5'] = 0, x['6'] = 1
# for those feature val owned only by one party, dict val shall
not be conflicting.
print(x, y)
"""
ip = host_addr.split(":")[0]
port = int(host_addr.split(":")[1])
if is_client:
intersection = input_set
intersection = mdu.recv_psi(ip, port, intersection)
intersection = sorted(list(intersection))
# Only the receiver can obtain the result.
# Send result to other parties.
else:
ret_code = mdu.send_psi(port, input_set)
if ret_code != 0:
raise RuntimeError("Errors occurred in PSI send lib, "
"error code = {}".format(ret_code))
if not is_client:
server = Listener((ip, port))
conn = Client((ip, port)) if is_client else server.accept()
if is_client:
conn.send(intersection)
diff_size_local = len(input_set) - len(intersection)
conn.send(diff_size_local)
diff_size_remote = conn.recv()
else:
intersection = conn.recv()
diff_size_local = len(input_set) - len(intersection)
diff_size_remote = conn.recv()
conn.send(diff_size_local)
conn.close()
if not is_client:
server.close()
ret = dict()
cnt = 0
for x in intersection:
ret[x] = cnt
cnt += 1
if is_client:
cnt += diff_size_remote
for x in [x for x in input_set if x not in intersection]:
ret[x] = cnt
cnt += 1
return ret, len(intersection) + diff_size_local + diff_size_remote
......@@ -50,8 +50,6 @@ class KSstatistic(MetricBase):
import paddle_fl.mpc
import numpy as np
# init the KSstatistic
ks = paddle_fl.mpc.metrics.KSstatistic('ks')
# suppose that batch_size is 128
batch_num = 100
......@@ -65,6 +63,10 @@ class KSstatistic(MetricBase):
preds = np.concatenate((class0_preds, class1_preds), axis=1)
labels = np.random.randint(2, size = (batch_size, 1))
# init the KSstatistic for each batch
# to get global ks statistic, init ks before for-loop
ks = paddle_fl.mpc.metrics.KSstatistic('ks')
ks.update(preds = preds, labels = labels)
# shall be some score closing to 0.1 as the preds are randomly assigned
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册