提交 17491aef 编写于 作者: H He,Kai

fix bugs

上级 9a3ac286
...@@ -21,7 +21,6 @@ namespace psi { ...@@ -21,7 +21,6 @@ namespace psi {
PseudorandomNumberGenerator::PseudorandomNumberGenerator(const block &seed) PseudorandomNumberGenerator::PseudorandomNumberGenerator(const block &seed)
: _ctr(0), _now_byte(0) { : _ctr(0), _now_byte(0) {
set_seed(seed); set_seed(seed);
refill_buffer();
} }
void PseudorandomNumberGenerator::set_seed(const block &b) { void PseudorandomNumberGenerator::set_seed(const block &b) {
...@@ -59,4 +58,10 @@ void PseudorandomNumberGenerator::get_array(void *res, size_t len) { ...@@ -59,4 +58,10 @@ void PseudorandomNumberGenerator::get_array(void *res, size_t len) {
} }
} }
template <>
bool PseudorandomNumberGenerator::get<bool>() {
uint8_t data;
get_array(&data, sizeof(data));
return data & 1;
}
} // namespace smc } // namespace smc
...@@ -338,7 +338,7 @@ def _transpile_type_and_shape(block): ...@@ -338,7 +338,7 @@ def _transpile_type_and_shape(block):
for op in block.ops: for op in block.ops:
if _is_supported_op(op.type): if _is_supported_op(op.type):
if op.type == 'fill_constant': if op.type == 'fill_constant':
op._set_attr(name='shape', val=(2L, 1L)) op._set_attr(name='shape', val=(2, 1))
# set default MPC value for fill_constant OP # set default MPC value for fill_constant OP
op._set_attr(name='value', val=MPC_ONE_SHARE) op._set_attr(name='value', val=MPC_ONE_SHARE)
op._set_attr(name='dtype', val=3) op._set_attr(name='dtype', val=3)
...@@ -482,7 +482,7 @@ def decrypt_model(mpc_model_dir, plain_model_path, mpc_model_filename=None, plai ...@@ -482,7 +482,7 @@ def decrypt_model(mpc_model_dir, plain_model_path, mpc_model_filename=None, plai
new_type = str(mpc_op.type)[len(MPC_OP_PREFIX):] new_type = str(mpc_op.type)[len(MPC_OP_PREFIX):]
mpc_op.desc.set_type(new_type) mpc_op.desc.set_type(new_type)
elif mpc_op.type == 'fill_constant': elif mpc_op.type == 'fill_constant':
mpc_op._set_attr(name='shape', val=(1L)) mpc_op._set_attr(name='shape', val=(1))
mpc_op._set_attr(name='value', val=1.0) mpc_op._set_attr(name='value', val=1.0)
mpc_op._set_attr(name='dtype', val=5) mpc_op._set_attr(name='dtype', val=5)
......
...@@ -17,14 +17,16 @@ This module test align in aby3 module. ...@@ -17,14 +17,16 @@ This module test align in aby3 module.
""" """
import unittest import unittest
from multiprocessing import Process import multiprocessing as mp
import paddle_fl.mpc.data_utils.alignment as alignment import paddle_fl.mpc.data_utils.alignment as alignment
class TestDataUtilsAlign(unittest.TestCase): class TestDataUtilsAlign(unittest.TestCase):
def run_align(self, input_set, party_id, endpoints, is_receiver): @staticmethod
def run_align(input_set, party_id, endpoints, is_receiver, ret_list):
""" """
Call align function in data_utils. Call align function in data_utils.
:param input_set: :param input_set:
...@@ -37,7 +39,7 @@ class TestDataUtilsAlign(unittest.TestCase): ...@@ -37,7 +39,7 @@ class TestDataUtilsAlign(unittest.TestCase):
party_id=party_id, party_id=party_id,
endpoints=endpoints, endpoints=endpoints,
is_receiver=is_receiver) is_receiver=is_receiver)
self.assertEqual(result, {'0'}) ret_list.append(result)
def test_align(self): def test_align(self):
""" """
...@@ -49,14 +51,27 @@ class TestDataUtilsAlign(unittest.TestCase): ...@@ -49,14 +51,27 @@ class TestDataUtilsAlign(unittest.TestCase):
set_1 = {'0', '10', '11', '111'} set_1 = {'0', '10', '11', '111'}
set_2 = {'0', '30', '33', '333'} set_2 = {'0', '30', '33', '333'}
party_0 = Process(target=self.run_align, args=(set_0, 0, endpoints, True)) mp.set_start_method('spawn')
party_1 = Process(target=self.run_align, args=(set_1, 1, endpoints, False))
party_2 = Process(target=self.run_align, args=(set_2, 2, endpoints, False)) manager = mp.Manager()
ret_list = manager.list()
party_0 = mp.Process(target=self.run_align, args=(set_0, 0, endpoints, True, ret_list))
party_1 = mp.Process(target=self.run_align, args=(set_1, 1, endpoints, False, ret_list))
party_2 = mp.Process(target=self.run_align, args=(set_2, 2, endpoints, False, ret_list))
party_1.start() party_1.start()
party_2.start() party_2.start()
party_0.start() party_0.start()
party_0.join() party_0.join()
party_1.join()
party_2.join()
self.assertEqual(3, len(ret_list))
self.assertEqual(ret_list[0], ret_list[1])
self.assertEqual(ret_list[0], ret_list[2])
self.assertEqual({'0'}, ret_list[0])
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册