# test_op_base.py (2.6 KB)
# Last change: "update", committed by jingqinghe.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Set base config for op unit tests.
"""
from multiprocessing import Pipe, Process
import os
import traceback
import unittest

import redis


class Aby3Process(Process):
    """
    Extends from Process, evaluate the computation party in aby3.

    The outcome of the target is reported back through a Pipe: after the
    process has finished, ``exception`` is ``None`` if the target returned
    normally, or an ``(exception, traceback_str)`` tuple if it raised.
    """
    def __init__(self, *args, **kwargs):
        Process.__init__(self, *args, **kwargs)
        # run() writes the outcome to _cconn; the exception property reads _pconn.
        self._pconn, self._cconn = Pipe()
        self._exception = None

    def run(self):
        """
        Override. Run the target, then send ``None`` on success or
        ``(exception, traceback_str)`` on failure to the other end of the pipe.
        """
        try:
            Process.run(self)
            self._cconn.send(None)
        except Exception as e:
            tb = traceback.format_exc()
            try:
                self._cconn.send((e, tb))
            except Exception:
                # The exception object itself could not be sent (e.g. it is
                # not picklable). Send a picklable stand-in so the reader
                # still sees a failure instead of "no exception".
                self._cconn.send((RuntimeError(repr(e)), tb))

    @property
    def exception(self):
        """
        Get exception.

        Returns the ``(exception, traceback_str)`` tuple sent by ``run()``,
        or ``None``. ``None`` means either success or that nothing has been
        received yet (``poll()`` does not wait), so read this after
        ``join()``. Once received, the value is cached.
        """
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        return self._exception


class TestOpBase(unittest.TestCase):
    """
    Base class for op unit tests that run several computation parties,
    each in its own Aby3Process.

    The redis server is read from the TEST_REDIS_IP and TEST_REDIS_PORT
    environment variables.
    """
    def __init__(self, methodName='runTest'):
        super(TestOpBase, self).__init__(methodName)
        # set redis server and port (a missing variable raises KeyError here)
        self.server = os.environ['TEST_REDIS_IP']
        self.port = os.environ['TEST_REDIS_PORT']
        self.party_num = 3

    def setUp(self):
        """
        Connect redis and delete all keys in all databases on the current host.
        :return:
        """
        r = redis.Redis(host=self.server, port=int(self.port))
        r.flushall()

    def multi_party_run(self, **kwargs):
        """
        Run ``self.party_num`` parties with target function or other additional arguments.

        :param kwargs: must contain ``target``, the function every party runs.
            Each party calls ``target`` with all of ``kwargs`` (including the
            ``target`` entry itself) plus ``role``, the party index
            ``0 .. party_num - 1``.
        :return: ``(True,)`` if every party finished without error; otherwise
            the ``(exception, traceback_str)`` tuple of the first failing
            party in role order.
        """
        target = kwargs['target']

        parties = []
        for role in range(self.party_num):
            # Give every party its own copy so roles never share one dict.
            party_kwargs = dict(kwargs, role=role)
            parties.append(Aby3Process(target=target, kwargs=party_kwargs))
            parties[-1].start()

        for index, party in enumerate(parties):
            party.join()
            failure = party.exception
            if failure is None and party.exitcode != 0:
                # The party died without reporting an exception (e.g. it was
                # killed, or its exception could not be sent back).
                failure = (RuntimeError('party %d exited with code %s'
                                        % (index, party.exitcode)), '')
            if failure:
                # Do not leave the remaining parties running as orphans.
                for other in parties[index + 1:]:
                    other.terminate()
                    other.join()
                return failure
        return (True,)