提交 08ea0b2a 编写于 作者: H He, Kai

fix ut

上级 637e476b
...@@ -140,13 +140,13 @@ class OpTest(unittest.TestCase): ...@@ -140,13 +140,13 @@ class OpTest(unittest.TestCase):
target = kwargs['target'] target = kwargs['target']
partys = [] parties = []
for role in range(self.party_num): for role in range(self.party_num):
kwargs.update({'role': role}) kwargs.update({'role': role})
partys.append(Aby3Process(target=target, kwargs=kwargs)) parties.append(Aby3Process(target=target, kwargs=kwargs))
partys[-1].start() parties[-1].start()
for party in partys: for party in parties:
party.join() party.join()
if party.exception: if party.exception:
return party.exception return party.exception
......
...@@ -77,13 +77,13 @@ class TestOpBase(unittest.TestCase): ...@@ -77,13 +77,13 @@ class TestOpBase(unittest.TestCase):
""" """
target = kwargs['target'] target = kwargs['target']
parties = []
for role in range(self.party_num): for role in range(self.party_num):
kwargs.update({'role': role}) kwargs.update({'role': role})
party = Aby3Process(target=target, kwargs=kwargs) parties.append(Aby3Process(target=target, kwargs=kwargs))
party.start() parties[-1].start()
if role == self.party_num - 1: for party in parties:
party.join() party.join()
if party.exception: if party.exception:
return party.exception return party.exception
else:
return (True,) return (True,)
...@@ -19,10 +19,10 @@ import unittest ...@@ -19,10 +19,10 @@ import unittest
from multiprocessing import Manager from multiprocessing import Manager
import numpy as np import numpy as np
import test_op_base import test_op_base
from op_test import OpTest from op_test import OpTest
import paddle_fl.mpc.data_utils.aby3 as aby3 import paddle_fl.mpc.data_utils.aby3 as aby3
import mpc_data_utils as mdu
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
...@@ -190,7 +190,7 @@ class TestConv2dOp(OpTest): ...@@ -190,7 +190,7 @@ class TestConv2dOp(OpTest):
'dilation': self.dilations 'dilation': self.dilations
} }
share = lambda x: np.array([x * 65536/3] * 2).astype('int64') share = lambda x: np.array([x * mdu.mpc_one_share] * 2).astype('int64')
input = np.random.random(self.input_size) input = np.random.random(self.input_size)
filter = np.random.uniform(-1, 1, self.filter_size) filter = np.random.uniform(-1, 1, self.filter_size)
...@@ -385,7 +385,7 @@ class TestConv2dOp_v2(OpTest): ...@@ -385,7 +385,7 @@ class TestConv2dOp_v2(OpTest):
'dilation': self.dilations 'dilation': self.dilations
} }
share = lambda x: np.array([x * 65536/3] * 2).astype('int64') share = lambda x: np.array([x * mdu.mpc_one_share] * 2).astype('int64')
input = np.random.random(self.input_size) input = np.random.random(self.input_size)
filter = np.random.uniform(-1, 1, self.filter_size) filter = np.random.uniform(-1, 1, self.filter_size)
......
...@@ -20,6 +20,7 @@ import unittest ...@@ -20,6 +20,7 @@ import unittest
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle_fl.mpc as pfl_mpc import paddle_fl.mpc as pfl_mpc
import mpc_data_utils as mdu
import test_op_base import test_op_base
...@@ -92,7 +93,7 @@ class TestOpPrecisionRecall(test_op_base.TestOpBase): ...@@ -92,7 +93,7 @@ class TestOpPrecisionRecall(test_op_base.TestOpBase):
self.threshold = np.random.random() self.threshold = np.random.random()
preds, labels = [], [] preds, labels = [], []
self.exp_res = (0, [0] * 3) self.exp_res = (0, [0] * 3)
share = lambda x: np.array([x * 65536/3] * 2).astype('int64').reshape( share = lambda x: np.array([x * mdu.mpc_one_share] * 2).astype('int64').reshape(
[2] + self.input_size) [2] + self.input_size)
for _ in range(n): for _ in range(n):
......
...@@ -62,7 +62,7 @@ class TestOpPool2d(test_op_base.TestOpBase): ...@@ -62,7 +62,7 @@ class TestOpPool2d(test_op_base.TestOpBase):
expected_out = np.array( expected_out = np.array(
[[[[6, 8, 100], [[[[6, 8, 100],
[14, 16, 200]]]]).astype('float32') [14, 16, 200]]]]).astype('float32')
print("input data_1: {} \n".format(data_1)) # print("input data_1: {} \n".format(data_1))
data_1_shares = aby3.make_shares(data_1) data_1_shares = aby3.make_shares(data_1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册