test_op_reshape.py 2.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module test sub op.

"""
import unittest
from multiprocessing import Manager
import numpy as np
import paddle.fluid as fluid
import paddle_fl.mpc as pfl_mpc
import paddle_fl.mpc.data_utils.aby3 as aby3
import test_op_base


class TestOpReshape(test_op_base.TestOpBase):

    def reshape(self, **kwargs):
        """
        Normal case.
        :param kwargs:
        :return:
        """

        role = kwargs['role']
        d_1 = kwargs['data_1'][role]
        return_results = kwargs['return_results']

        pfl_mpc.init("aby3", role, "localhost", self.server, int(self.port))
        data_1 = pfl_mpc.data(name='x', shape=[2, 2], dtype='int64')
        op_reshape = pfl_mpc.layers.reshape(data_1, [2, 1, 4])
        exe = fluid.Executor(place=fluid.CPUPlace())
        results = exe.run(feed={'x': d_1}, fetch_list=[op_reshape])

        self.assertEqual(results[0].shape, (2, 1, 4))
        return_results.append(results[0])

    def test_reshape(self):

        data_1 = np.full(shape=(2, 2), fill_value=2)
        data_1_shares = aby3.make_shares(data_1)
        data_1_all3shares = np.array([aby3.get_aby3_shares(data_1_shares, i) for i in range(3)])

        return_results = Manager().list()
        ret = self.multi_party_run(target=self.reshape,
                                   data_1=data_1_all3shares,
                                   return_results=return_results)
        self.assertEqual(ret[0], True)
        revealed = aby3.reconstruct(np.array(return_results))
        expected_out = np.array([[2, 2, 2, 2]])
        self.assertTrue(np.allclose(revealed, expected_out, atol=1e-4))
 

if __name__ == '__main__':
    unittest.main()