From 08ea0b2aaba466eef9c79e10dddeadb06348619c Mon Sep 17 00:00:00 2001 From: "He, Kai" Date: Fri, 11 Sep 2020 03:53:59 +0000 Subject: [PATCH] fix ut --- python/paddle_fl/mpc/tests/unittests/op_test.py | 8 ++++---- .../mpc/tests/unittests/test_op_base.py | 16 ++++++++-------- .../mpc/tests/unittests/test_op_conv.py | 6 +++--- .../mpc/tests/unittests/test_op_metric.py | 3 ++- .../mpc/tests/unittests/test_op_pool.py | 10 +++++----- 5 files changed, 22 insertions(+), 21 deletions(-) diff --git a/python/paddle_fl/mpc/tests/unittests/op_test.py b/python/paddle_fl/mpc/tests/unittests/op_test.py index 9064074..50e2b54 100644 --- a/python/paddle_fl/mpc/tests/unittests/op_test.py +++ b/python/paddle_fl/mpc/tests/unittests/op_test.py @@ -140,13 +140,13 @@ class OpTest(unittest.TestCase): target = kwargs['target'] - partys = [] + parties = [] for role in range(self.party_num): kwargs.update({'role': role}) - partys.append(Aby3Process(target=target, kwargs=kwargs)) - partys[-1].start() - for party in partys: + parties.append(Aby3Process(target=target, kwargs=kwargs)) + parties[-1].start() + for party in parties: party.join() if party.exception: return party.exception diff --git a/python/paddle_fl/mpc/tests/unittests/test_op_base.py b/python/paddle_fl/mpc/tests/unittests/test_op_base.py index 3f41f69..3ca84d8 100644 --- a/python/paddle_fl/mpc/tests/unittests/test_op_base.py +++ b/python/paddle_fl/mpc/tests/unittests/test_op_base.py @@ -77,13 +77,13 @@ class TestOpBase(unittest.TestCase): """ target = kwargs['target'] + parties = [] for role in range(self.party_num): kwargs.update({'role': role}) - party = Aby3Process(target=target, kwargs=kwargs) - party.start() - if role == self.party_num - 1: - party.join() - if party.exception: - return party.exception - else: - return (True,) + parties.append(Aby3Process(target=target, kwargs=kwargs)) + parties[-1].start() + for party in parties: + party.join() + if party.exception: + return party.exception + return (True,) diff --git a/python/paddle_fl/mpc/tests/unittests/test_op_conv.py b/python/paddle_fl/mpc/tests/unittests/test_op_conv.py index ee45be2..307c636 100644 --- a/python/paddle_fl/mpc/tests/unittests/test_op_conv.py +++ b/python/paddle_fl/mpc/tests/unittests/test_op_conv.py @@ -19,10 +19,10 @@ import unittest from multiprocessing import Manager import numpy as np - import test_op_base from op_test import OpTest import paddle_fl.mpc.data_utils.aby3 as aby3 +import mpc_data_utils as mdu import paddle.fluid as fluid import paddle.fluid.core as core @@ -190,7 +190,7 @@ class TestConv2dOp(OpTest): 'dilation': self.dilations } - share = lambda x: np.array([x * 65536/3] * 2).astype('int64') + share = lambda x: np.array([x * mdu.mpc_one_share] * 2).astype('int64') input = np.random.random(self.input_size) filter = np.random.uniform(-1, 1, self.filter_size) @@ -385,7 +385,7 @@ class TestConv2dOp_v2(OpTest): 'dilation': self.dilations } - share = lambda x: np.array([x * 65536/3] * 2).astype('int64') + share = lambda x: np.array([x * mdu.mpc_one_share] * 2).astype('int64') input = np.random.random(self.input_size) filter = np.random.uniform(-1, 1, self.filter_size) diff --git a/python/paddle_fl/mpc/tests/unittests/test_op_metric.py b/python/paddle_fl/mpc/tests/unittests/test_op_metric.py index 9260957..873ad69 100644 --- a/python/paddle_fl/mpc/tests/unittests/test_op_metric.py +++ b/python/paddle_fl/mpc/tests/unittests/test_op_metric.py @@ -20,6 +20,7 @@ import unittest import numpy as np import paddle.fluid as fluid import paddle_fl.mpc as pfl_mpc +import mpc_data_utils as mdu import test_op_base @@ -92,7 +93,7 @@ class TestOpPrecisionRecall(test_op_base.TestOpBase): self.threshold = np.random.random() preds, labels = [], [] self.exp_res = (0, [0] * 3) - share = lambda x: np.array([x * 65536/3] * 2).astype('int64').reshape( + share = lambda x: np.array([x * mdu.mpc_one_share] * 2).astype('int64').reshape( [2] + self.input_size) for _ in range(n): diff --git a/python/paddle_fl/mpc/tests/unittests/test_op_pool.py b/python/paddle_fl/mpc/tests/unittests/test_op_pool.py index 6856897..c519dc8 100644 --- a/python/paddle_fl/mpc/tests/unittests/test_op_pool.py +++ b/python/paddle_fl/mpc/tests/unittests/test_op_pool.py @@ -54,15 +54,15 @@ class TestOpPool2d(test_op_base.TestOpBase): def test_pool2d(self): data_1 = np.array( - [[[[1, 2, 3, 4, 0, 100], - [5, 6, 7, 8, 0, 100], + [[[[1, 2, 3, 4, 0, 100], + [5, 6, 7, 8, 0, 100], [9, 10, 11, 12, 0, 200], [13, 14, 15, 16, 0, 200]]]]).astype('float32') - + expected_out = np.array( - [[[[6, 8, 100], + [[[[6, 8, 100], [14, 16, 200]]]]).astype('float32') - print("input data_1: {} \n".format(data_1)) + # print("input data_1: {} \n".format(data_1)) data_1_shares = aby3.make_shares(data_1) -- GitLab