提交 17491aef 编写于 作者: H He,Kai

fix bugs

上级 9a3ac286
......@@ -21,7 +21,6 @@ namespace psi {
PseudorandomNumberGenerator::PseudorandomNumberGenerator(const block &seed)
: _ctr(0), _now_byte(0) {
set_seed(seed);
refill_buffer();
}
void PseudorandomNumberGenerator::set_seed(const block &b) {
......@@ -59,4 +58,10 @@ void PseudorandomNumberGenerator::get_array(void *res, size_t len) {
}
}
template <>
bool PseudorandomNumberGenerator::get<bool>() {
uint8_t data;
get_array(&data, sizeof(data));
return data & 1;
}
} // namespace smc
......@@ -338,7 +338,7 @@ def _transpile_type_and_shape(block):
for op in block.ops:
if _is_supported_op(op.type):
if op.type == 'fill_constant':
op._set_attr(name='shape', val=(2L, 1L))
op._set_attr(name='shape', val=(2, 1))
# set default MPC value for fill_constant OP
op._set_attr(name='value', val=MPC_ONE_SHARE)
op._set_attr(name='dtype', val=3)
......@@ -482,7 +482,7 @@ def decrypt_model(mpc_model_dir, plain_model_path, mpc_model_filename=None, plai
new_type = str(mpc_op.type)[len(MPC_OP_PREFIX):]
mpc_op.desc.set_type(new_type)
elif mpc_op.type == 'fill_constant':
mpc_op._set_attr(name='shape', val=(1L))
mpc_op._set_attr(name='shape', val=(1))
mpc_op._set_attr(name='value', val=1.0)
mpc_op._set_attr(name='dtype', val=5)
......
......@@ -17,14 +17,16 @@ This module test align in aby3 module.
"""
import unittest
from multiprocessing import Process
import multiprocessing as mp
import paddle_fl.mpc.data_utils.alignment as alignment
class TestDataUtilsAlign(unittest.TestCase):
def run_align(self, input_set, party_id, endpoints, is_receiver):
@staticmethod
def run_align(input_set, party_id, endpoints, is_receiver, ret_list):
"""
Call align function in data_utils.
:param input_set:
......@@ -37,7 +39,7 @@ class TestDataUtilsAlign(unittest.TestCase):
party_id=party_id,
endpoints=endpoints,
is_receiver=is_receiver)
self.assertEqual(result, {'0'})
ret_list.append(result)
def test_align(self):
"""
......@@ -49,14 +51,27 @@ class TestDataUtilsAlign(unittest.TestCase):
set_1 = {'0', '10', '11', '111'}
set_2 = {'0', '30', '33', '333'}
party_0 = Process(target=self.run_align, args=(set_0, 0, endpoints, True))
party_1 = Process(target=self.run_align, args=(set_1, 1, endpoints, False))
party_2 = Process(target=self.run_align, args=(set_2, 2, endpoints, False))
mp.set_start_method('spawn')
manager = mp.Manager()
ret_list = manager.list()
party_0 = mp.Process(target=self.run_align, args=(set_0, 0, endpoints, True, ret_list))
party_1 = mp.Process(target=self.run_align, args=(set_1, 1, endpoints, False, ret_list))
party_2 = mp.Process(target=self.run_align, args=(set_2, 2, endpoints, False, ret_list))
party_1.start()
party_2.start()
party_0.start()
party_0.join()
party_1.join()
party_2.join()
self.assertEqual(3, len(ret_list))
self.assertEqual(ret_list[0], ret_list[1])
self.assertEqual(ret_list[0], ret_list[2])
self.assertEqual({'0'}, ret_list[0])
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册