diff --git a/python/paddle_fl/mpc/layers/math.py b/python/paddle_fl/mpc/layers/math.py index 7e9d27007c839b59f3a82e3a087eb172963a044f..7eec73e674d877a43a90f8cfa360e2dc22b2daed 100644 --- a/python/paddle_fl/mpc/layers/math.py +++ b/python/paddle_fl/mpc/layers/math.py @@ -24,6 +24,7 @@ __all__ = [ 'square', 'sum', 'square_error_cost', + 'reduce_sum' ] @@ -128,3 +129,71 @@ def square_error_cost(input, label): inputs={'X': [minus_out]}, outputs={'Out': [square_out]}) return square_out + + + +def reduce_sum(input, dim=None, keep_dim=False, name=None): + """ + Computes the sum of tensor elements over the given dimension. + + Args: + input (MpcVariable) The input of sum op name(basestring|None): Name of the output. + dim (list|int, optional): The dimensions along which the sum is performed. If + :attr:`None`, sum all elements of :attr:`input` and return a + Tensor variable with a single element, otherwise must be in the + range :math:`[-rank(input), rank(input))`. If :math:`dim[i] < 0`, + the dimension to reduce is :math:`rank + dim[i]`. + NOTE: 'dim' should not contain 0, becausedims[0] is share number. + keep_dim (bool, optional): Whether to reserve the reduced dimension in the + output Tensor. The result tensor will have one fewer dimension + than the :attr:`input` unless :attr:`keep_dim` is true, default + value is False. + name(str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name` + Returns: + Variable: Tensor, results of summation operation on the specified dim of input tensor, + it's data type is the same as input's Tensor. + Raises: + TypeError, if out data type is different with the input data type. + + Returns: + out(MpcVariable): (Tensor) The output of mean op + Examples: + .. code-block:: python + + import paddle_fl.mpc as pfl_mpc + + pfl_mpc.init("aby3", int(args.role), "localhost", args.server, int(args.port)) + data_1 = pfl_mpc.data(name='x', shape=[3, 3], dtype='int64') + pfl_mpc.layers.reshape(data_1, [1, 2]) # shape: [2, 1, 1] + # data_1 = np.full(shape=(3, 4), fill_value=2) + # reduce_sum: 24 + """ + if dim is not None and not isinstance(dim, list): + dim = [dim] + + if dim != None and dim != []: + if 0 in dim: + raise ValueError( + "'dim' should not contain 0, because dim[0] is share number." + ) + else: + dim = [i for i in range(len(input.shape))][1:] + + attrs = { + 'dim': dim, + 'keep_dim': keep_dim, + 'reduce_all': False + } + check_mpc_variable_and_dtype( + input, 'input', ['int64'], 'reduce_sum') + helper = MpcLayerHelper('reduce_sum', **locals()) + out = helper.create_mpc_variable_for_type_inference(dtype=helper.input_dtype()) + helper.append_op( + type='reduce_sum', + inputs={'X': input}, + outputs={'Out': out}, + attrs=attrs) + return out + + diff --git a/python/paddle_fl/mpc/layers/ml.py b/python/paddle_fl/mpc/layers/ml.py index cd67a8556cf8e176b64e6f435732ac2d18d1b02b..c45953acaeb2e450dd4098a05868350ec6a1b85b 100644 --- a/python/paddle_fl/mpc/layers/ml.py +++ b/python/paddle_fl/mpc/layers/ml.py @@ -21,6 +21,7 @@ import mpc_data_utils as mdu from paddle.fluid.data_feeder import check_type, check_dtype import paddle.fluid.layers.utils as utils from paddle.fluid.initializer import Constant +from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.param_attr import ParamAttr from paddle.fluid.framework import Variable from ..framework import MpcVariable @@ -35,6 +36,7 @@ __all__ = [ 'softmax_with_cross_entropy', 'pool2d', 'batch_norm', + 'reshape', ] @@ -550,3 +552,131 @@ def batch_norm(input, type="mpc_batch_norm", inputs=inputs, outputs=outputs, attrs=attrs) return helper.append_activation(batch_norm_out) + + +def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): + """ + This operator changes the shape of ``x`` without changing its data. + + The target shape can be given by ``shape`` or ``actual_shape``. + When ``shape`` and ``actual_shape`` are set at the same time, + ``actual_shape`` has a higher priority than ``shape`` + but at this time ``shape`` can only be an integer list or tuple, and ``shape`` still should be set correctly to + guarantee shape inference in compile-time. + + Some tricks exist when specifying the target shape. + + 1. -1 means the value of this dimension is inferred from the total element + number of x and remaining dimensions. Thus one and only one dimension can + be set -1. + + 2. 0 means the actual dimension value is going to be copied from the + corresponding dimension of x. The index of 0s in shape can not exceed + the dimension of x. + + Args: + x(Variable): A ``Tensor`` or ``LoDTensor`` . The data type is ``int64``. + shape(list|tuple|Variable): Define the target shape. At most one dimension of the target shape can be -1. + The data type is ``int32`` . If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [1]. + If ``shape`` is an Variable, it should be an 1-D Tensor . + actual_shape(variable, optional): An 1-D ``Tensor`` or ``LoDTensor`` . The data type is ``int32`` . If provided, reshape + according to this given shape rather than ``shape`` specifying shape. + That is to say ``actual_shape`` has a higher priority + than ``shape(list|tuple)`` but not ``shape(Variable)``. \ + This argument ``actual_shape`` will be removed in a future version. \ + act (str, optional): The non-linear activation to be applied to the reshaped input. Default None. + inplace(bool, optional): If ``inplace`` is True, the input and output of ``layers.reshape`` + are the same variable. Otherwise, the input and output of + ``layers.reshape`` are different variable. Default False. Note that if ``x`` + is more than one OPs' input, ``inplace`` must be False. + name(str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Variable: A ``Tensor`` or ``LoDTensor``. The data type is same as ``x``. It is a new tensor variable if ``inplace`` is ``False``, otherwise it is ``x``. If ``act`` is None, return the reshaped tensor variable, otherwise return the activated tensor variable. + + + Examples: + .. code-block:: python + import paddle_fl.mpc as pfl_mpc + + pfl_mpc.init("aby3", int(args.role), "localhost", args.server, int(args.port)) + data_1 = pfl_mpc.data(name='x', shape=[3, 3], dtype='int64') + op_reshape = pfl_mpc.layers.reshape(data_1, [2, 1, 9]) + """ + + check_mpc_variable_and_dtype( + x, 'x', ['int64'], 'reshape') + check_type(shape, 'shape', (list, tuple, Variable), 'reshape') + check_type(actual_shape, 'actual_shape', (Variable, type(None)), 'reshape') + + helper = MpcLayerHelper("reshape2", **locals()) + _helper = LayerHelper("reshape2", **locals()) + + def get_new_shape_tensor(list_shape): + new_shape_tensor = [] + for dim in list_shape: + if isinstance(dim, Variable): + dim.stop_gradient = True + new_shape_tensor.append(dim) + else: + assert (isinstance(dim, int)) + temp_out = _helper.create_variable_for_type_inference('int32') + fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out) + new_shape_tensor.append(temp_out) + return new_shape_tensor + + def get_attr_shape(list_shape): + unk_dim_idx = -1 + attrs_shape = [] + for dim_idx, dim_size in enumerate(list_shape): + if isinstance(dim_size, Variable): + attrs_shape.append(-1) + else: + attrs_shape.append(dim_size) + if dim_size == -1: + assert unk_dim_idx == -1, ( + "Only one dimension value of 'shape' in reshape can " + "be -1. But received shape[%d] is also -1." % dim_idx) + unk_dim_idx = dim_idx + elif dim_size == 0: + assert dim_idx < len(x.shape), ( + "The index of 0 in `shape` must be less than " + "the input tensor X's dimensions. " + "But received shape[%d] = 0, X's dimensions = %d." % + (dim_idx, len(x.shape))) + else: + assert dim_size > 0, ( + "Each dimension value of 'shape' in reshape must not " + "be negative except one unknown dimension. " + "But received shape[%d] = %s." % + (dim_idx, str(dim_size))) + return attrs_shape + + inputs = {"X": x} + attrs = {} + if isinstance(shape, Variable): + shape.stop_gradient = True + inputs["Shape"] = shape + elif isinstance(shape, (list, tuple)): + assert len(shape) > 0, ("The size of 'shape' in reshape can't be zero, " + "but received %s." % len(shape)) + attrs["shape"] = get_attr_shape(shape) + + if utils._contain_var(shape): + inputs['ShapeTensor'] = get_new_shape_tensor(shape) + elif isinstance(actual_shape, Variable): + actual_shape.stop_gradient = True + inputs["Shape"] = actual_shape + + out = x if inplace else helper.create_mpc_variable_for_type_inference( + dtype=x.dtype) + x_shape = helper.create_mpc_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type="reshape2", + inputs=inputs, + attrs=attrs, + outputs={"Out": out, + "XShape": x_shape}) + + return helper.append_activation(out) diff --git a/python/paddle_fl/mpc/tests/unittests/run_test_example.sh b/python/paddle_fl/mpc/tests/unittests/run_test_example.sh index 30ede58dfbbc401651ae6d832214e6b2d70f785c..3271ea21d24dbb9f1edd28a9fab1c32c4ab01c4e 100644 --- a/python/paddle_fl/mpc/tests/unittests/run_test_example.sh +++ b/python/paddle_fl/mpc/tests/unittests/run_test_example.sh @@ -26,6 +26,8 @@ TEST_MODULES=("test_datautils_aby3" "test_op_conv" "test_op_pool" "test_op_metric" +"test_op_reshape" +"test_op_reduce_sum" ) # run unittest diff --git a/python/paddle_fl/mpc/tests/unittests/test_op_reduce_sum.py b/python/paddle_fl/mpc/tests/unittests/test_op_reduce_sum.py new file mode 100644 index 0000000000000000000000000000000000000000..0c7260f3a370cc82974872f0b05f586db80c7ea9 --- /dev/null +++ b/python/paddle_fl/mpc/tests/unittests/test_op_reduce_sum.py @@ -0,0 +1,66 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module test sub op. + +""" +import unittest +from multiprocessing import Manager +import numpy as np +import paddle.fluid as fluid +import paddle_fl.mpc as pfl_mpc +import paddle_fl.mpc.data_utils.aby3 as aby3 +import test_op_base + + +class TestOpReduceSum(test_op_base.TestOpBase): + + def reduce_sum(self, **kwargs): + """ + Normal case. + :param kwargs: + :return: + """ + + role = kwargs['role'] + d_1 = kwargs['data_1'][role] + return_results = kwargs['return_results'] + + pfl_mpc.init("aby3", role, "localhost", self.server, int(self.port)) + data_1 = pfl_mpc.data(name='x', shape=[3, 4], dtype='int64') + op_reduce_sum = pfl_mpc.layers.reduce_sum(data_1, [1, 2], keep_dim=True) + exe = fluid.Executor(place=fluid.CPUPlace()) + results = exe.run(feed={'x': d_1}, fetch_list=[op_reduce_sum]) + + self.assertEqual(results[0].shape, (2, 1, 1)) + return_results.append(results[0]) + + def test_reduce_sum(self): + + data_1 = np.full(shape=(3, 4), fill_value=2) + data_1_shares = aby3.make_shares(data_1) + data_1_all3shares = np.array([aby3.get_aby3_shares(data_1_shares, i) for i in range(3)]) + + return_results = Manager().list() + ret = self.multi_party_run(target=self.reduce_sum, + data_1=data_1_all3shares, + return_results=return_results) + self.assertEqual(ret[0], True) + revealed = aby3.reconstruct(np.array(return_results)) + expected_out = np.array([[24]]) + self.assertTrue(np.allclose(revealed, expected_out, atol=1e-4)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle_fl/mpc/tests/unittests/test_op_reshape.py b/python/paddle_fl/mpc/tests/unittests/test_op_reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..790fc42080684ae3c0198dc63b9ea643da1a7763 --- /dev/null +++ b/python/paddle_fl/mpc/tests/unittests/test_op_reshape.py @@ -0,0 +1,66 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module test sub op. + +""" +import unittest +from multiprocessing import Manager +import numpy as np +import paddle.fluid as fluid +import paddle_fl.mpc as pfl_mpc +import paddle_fl.mpc.data_utils.aby3 as aby3 +import test_op_base + + +class TestOpReshape(test_op_base.TestOpBase): + + def reshape(self, **kwargs): + """ + Normal case. + :param kwargs: + :return: + """ + + role = kwargs['role'] + d_1 = kwargs['data_1'][role] + return_results = kwargs['return_results'] + + pfl_mpc.init("aby3", role, "localhost", self.server, int(self.port)) + data_1 = pfl_mpc.data(name='x', shape=[2, 2], dtype='int64') + op_reshape = pfl_mpc.layers.reshape(data_1, [2, 1, 4]) + exe = fluid.Executor(place=fluid.CPUPlace()) + results = exe.run(feed={'x': d_1}, fetch_list=[op_reshape]) + + self.assertEqual(results[0].shape, (2, 1, 4)) + return_results.append(results[0]) + + def test_reshape(self): + + data_1 = np.full(shape=(2, 2), fill_value=2) + data_1_shares = aby3.make_shares(data_1) + data_1_all3shares = np.array([aby3.get_aby3_shares(data_1_shares, i) for i in range(3)]) + + return_results = Manager().list() + ret = self.multi_party_run(target=self.reshape, + data_1=data_1_all3shares, + return_results=return_results) + self.assertEqual(ret[0], True) + revealed = aby3.reconstruct(np.array(return_results)) + expected_out = np.array([[2, 2, 2, 2]]) + self.assertTrue(np.allclose(revealed, expected_out, atol=1e-4)) + + +if __name__ == '__main__': + unittest.main()