diff --git a/core/paddlefl_mpc/mpc_protocol/aby3_operators.h b/core/paddlefl_mpc/mpc_protocol/aby3_operators.h
index c129ddac2c5188ae13ba813392ec18d3675ca669..77325ec609cc40e09d6dbaa33776072f124130da 100644
--- a/core/paddlefl_mpc/mpc_protocol/aby3_operators.h
+++ b/core/paddlefl_mpc/mpc_protocol/aby3_operators.h
@@ -319,13 +319,17 @@ public:
        auto a_tuple = from_tensor(in);
        auto a_ = std::get<0>(a_tuple).get();

-        auto b_tuple = from_tensor(pos_info);
-        auto b_ = std::get<0>(b_tuple).get();
-
        auto out_tuple = from_tensor(out);
        auto out_ = std::get<0>(out_tuple).get();

-        a_->max_pooling(out_, b_);
+        if (pos_info) {
+            auto b_tuple = from_tensor(pos_info);
+            auto b_ = std::get<0>(b_tuple).get();
+
+            a_->max_pooling(out_, b_);
+        } else {
+            a_->max_pooling(out_, nullptr);
+        }
    }

    void inverse_square_root(const Tensor* in, Tensor* out) override {
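Note: the hunk above makes `pos_info` optional, so a caller that only needs the max reduction, and not the one-hot argmax positions, can pass a null tensor; the new mean-normalize kernel below relies on this path. A rough plaintext analogue of the two calling modes (a numpy stand-in, not the PaddleFL API; unlike the secret-shared protocol, ties here would mark more than one position):

```python
import numpy as np

def max_pooling_plain(x, need_pos=True):
    # column-wise max over rows, the reduction aby3's max_pooling performs
    out = x.max(axis=0)
    if not need_pos:
        return out, None                   # the new pos_info == nullptr path
    pos = (x == out).astype(np.int64)      # one-hot rows marking each column's max
    return out, pos
```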
+ "But received (%d) != (%d)", + min_dims, max_dims)); + PADDLE_ENFORCE_EQ(min_dims, mean_dims, + platform::errors::InvalidArgument( + "The dimension of Input(Min) and " + "Input(Max) should be the same." + "But received (%d) != (%d)", + min_dims, mean_dims)); + PADDLE_ENFORCE_EQ( + min_dims.size(), 3, + platform::errors::InvalidArgument( + "The dimension of Input(Min) should be equal to 3 " + "(share_num, party_num, feature_num). But received (%d)", + min_dims.size())); + + PADDLE_ENFORCE_EQ( + sample_num_dims.size(), 2, + platform::errors::InvalidArgument( + "The dimension of Input(SampleNum) should be equal to 3 " + "(share_num, party_num). But received (%d)", + sample_num_dims.size())); + + PADDLE_ENFORCE_EQ( + sample_num_dims[1], min_dims[1], + platform::errors::InvalidArgument( + "The party num of Input(SampleNum) and Input(Min) " + "should be equal But received (%d) != (%d)", + sample_num_dims[1], min_dims[1])); + } + + ctx->SetOutputDim("Range", {mean_dims[0], mean_dims[2]}); + ctx->SetOutputDim("MeanOut", {mean_dims[0], mean_dims[2]}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Min"), + ctx.device_context()); + } +}; + +class MpcMeanNormalizationOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Min", + "(Tensor, default Tensor) A 2-D tensor with shape [P, N], " + "where P is the party num and N is the feature num. Each row contains " + " the local min feature val of N features."); + AddInput("Max", + "(Tensor, default Tensor) A 2-D tensor with shape [P, N], " + "where P is the party num and N is the feature num. Each row contains " + " the local max feature val of N features."); + AddInput("Mean", + "(Tensor, default Tensor) A 2-D tensor with shape [P, N], " + "where P is the party num and N is the feature num. Each row contains " + " the local mean feature val of N features."); + AddInput("SampleNum", + "(Tensor, default Tensor) A 1-D tensor with shape [P], " + "where P is the party num. Each element contains " + "sample num of party_i."); + AddOutput("Range", + "(Tensor, default Tensor) A 1-D tensor with shape [N], " + "where N is the feature num. Each element contains " + "global range of feature_i."); + AddOutput("MeanOut", + "(Tensor, default Tensor) A 1-D tensor with shape [N], " + "where N is the feature num. Each element contains " + "global mean of feature_i."); + AddAttr("total_sample_num", "(int) Sum of sample nums from all party."); + AddComment(R"DOC( +Mean normalization Operator. +When given Input(Min), Input(Max), Input(Mean) and Input(SampleNum), +this operator can be used to compute global range and mean for further feature +scaling. +Output(Range) is the global range of all features. +Output(MeanOut) is the global mean of all features. 
+)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + mpc_mean_normalize, ops::MpcMeanNormalizationOp, ops::MpcMeanNormalizationOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + mpc_mean_normalize, + ops::MpcMeanNormalizationKernel); diff --git a/core/paddlefl_mpc/operators/mpc_mean_normalize_op.h b/core/paddlefl_mpc/operators/mpc_mean_normalize_op.h new file mode 100644 index 0000000000000000000000000000000000000000..91f8ba24e5bacb3c52992fb2fa137b444ec0b079 --- /dev/null +++ b/core/paddlefl_mpc/operators/mpc_mean_normalize_op.h @@ -0,0 +1,94 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "mpc_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class MpcMeanNormalizationKernel : public MpcOpKernel { + public: + void ComputeImpl(const framework::ExecutionContext& context) const override { + const Tensor* min = context.Input("Min"); + const Tensor* max = context.Input("Max"); + const Tensor* mean = context.Input("Mean"); + const Tensor* sample_num = context.Input("SampleNum"); + + Tensor* range = context.Output("Range"); + Tensor* mean_out = context.Output("MeanOut"); + + int share_num = min->dims()[0]; + int party_num = min->dims()[1]; + int feat_num = min->dims()[2]; + + Tensor neg_min; + neg_min.mutable_data(min->dims(), context.GetPlace(), 0); + + Tensor neg_min_global; + Tensor max_global; + + neg_min_global.mutable_data( + framework::make_ddim({share_num, 1, feat_num}), context.GetPlace(), 0); + max_global.mutable_data( + framework::make_ddim({share_num, 1, feat_num}), context.GetPlace(), 0); + + mpc::MpcInstance::mpc_instance()->mpc_protocol() + ->mpc_operators()->neg(min, &neg_min); + + mpc::MpcInstance::mpc_instance()->mpc_protocol() + ->mpc_operators()->max_pooling(&neg_min, &neg_min_global, nullptr); + + mpc::MpcInstance::mpc_instance()->mpc_protocol() + ->mpc_operators()->max_pooling(max, &max_global, nullptr); + + range->mutable_data( + framework::make_ddim({share_num, 1, feat_num}), context.GetPlace(), 0); + + mpc::MpcInstance::mpc_instance()->mpc_protocol() + ->mpc_operators()->add(&max_global, &neg_min_global, range); + + range->mutable_data( + framework::make_ddim({share_num, feat_num}), context.GetPlace(), 0); + + // TODO: get total_sample_num by reduing size + int total_sample_num = context.Attr("total_sample_num"); + + Tensor sample_num_; + + sample_num_.ShareDataWith(*sample_num); + + sample_num_.mutable_data( + framework::make_ddim({share_num, 1, party_num}), context.GetPlace(), 0); + + mean_out->mutable_data( + framework::make_ddim({share_num, 1, feat_num}), context.GetPlace(), 0); + + mpc::MpcInstance::mpc_instance()->mpc_protocol() + ->mpc_operators()->matmul(&sample_num_, mean, mean_out); + + mean_out->mutable_data( + framework::make_ddim({share_num, feat_num}), 
diff --git a/python/paddle_fl/mpc/layers/__init__.py b/python/paddle_fl/mpc/layers/__init__.py
index 3f6b0f29d084d18224a38fd2853d7ede28b2561c..aebd09d388abb0883744403ac5ec2109afffbf67 100644
--- a/python/paddle_fl/mpc/layers/__init__.py
+++ b/python/paddle_fl/mpc/layers/__init__.py
@@ -37,6 +37,8 @@ from . import rnn
 from .rnn import *
 from . import metric_op
 from .metric_op import *
+from . import data_preprocessing
+from .data_preprocessing import *
 
 __all__ = []
 __all__ += basic.__all__
@@ -46,3 +48,4 @@ __all__ += ml.__all__
 __all__ += compare.__all__
 __all__ += conv.__all__
 __all__ += metric_op.__all__
+__all__ += data_preprocessing.__all__
diff --git a/python/paddle_fl/mpc/layers/data_preprocessing.py b/python/paddle_fl/mpc/layers/data_preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..e21999441a34f18fddf41c966d652ca6c92d2317
--- /dev/null
+++ b/python/paddle_fl/mpc/layers/data_preprocessing.py
@@ -0,0 +1,202 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+mpc data preprocessing op layers.
+"""
+from paddle.fluid.data_feeder import check_type, check_dtype
+from ..framework import check_mpc_variable_and_dtype
+from ..mpc_layer_helper import MpcLayerHelper
+
+__all__ = ['mean_normalize']
+
+def mean_normalize(f_min, f_max, f_mean, sample_num, total_sample_num):
+    '''
+    Mean normalization is a method used to normalize the range of independent
+    variables or features of data.
+    Refer to:
+    https://en.wikipedia.org/wiki/Feature_scaling#Mean_normalization
+
+    Args:
+        f_min (Variable): A 2-D tensor with shape [P, N], where P is the party
+                          num and N is the feature num. Each row contains the
+                          local min feature val of N features.
+        f_max (Variable): A 2-D tensor with shape [P, N], where P is the party
+                          num and N is the feature num. Each row contains the
+                          local max feature val of N features.
+        f_mean (Variable): A 2-D tensor with shape [P, N], where P is the party
+                           num and N is the feature num. Each row contains the
+                           local mean feature val of N features.
+        sample_num (Variable): A 1-D tensor with shape [P], where P is the
+                               party num. Each element contains the sample num
+                               of party_i.
+        total_sample_num (int): Sum of sample nums from all parties.
+
+    Returns:
+        f_range (Variable): A 1-D tensor with shape [N], where N is the
+                            feature num. Each element contains the global
+                            range of feature_i.
+        f_mean_out (Variable): A 1-D tensor with shape [N], where N is the
+                               feature num. Each element contains the global
+                               mean of feature_i.
+    Examples:
+        .. code-block:: python
+
+            from multiprocessing import Manager
+            from multiprocessing import Process
+            import numpy as np
+            import paddle.fluid as fluid
+            import paddle_fl.mpc as pfl_mpc
+            import mpc_data_utils as mdu
+            import paddle_fl.mpc.data_utils.aby3 as aby3
+
+
+            redis_server = "127.0.0.1"
+            redis_port = 9937
+            test_f_num = 100
+            # party i owns 2 + 2*i rows of data
+            test_row_split = range(2, 10, 2)
+
+
+            def mean_norm_naive(f_mat):
+                ma = np.amax(f_mat, axis=0)
+                mi = np.amin(f_mat, axis=0)
+                return ma - mi, np.mean(f_mat, axis=0)
+
+
+            def gen_data(f_num, sample_nums):
+                f_mat = np.random.rand(np.sum(sample_nums), f_num)
+
+                f_min, f_max, f_mean = [], [], []
+
+                prev_idx = 0
+
+                for n in sample_nums:
+                    i = prev_idx
+                    j = i + n
+
+                    ma = np.amax(f_mat[i:j], axis=0)
+                    mi = np.amin(f_mat[i:j], axis=0)
+                    me = np.mean(f_mat[i:j], axis=0)
+
+                    f_min.append(mi)
+                    f_max.append(ma)
+                    f_mean.append(me)
+
+                    prev_idx += n
+
+                f_min = np.array(f_min).reshape(sample_nums.size, f_num)
+                f_max = np.array(f_max).reshape(sample_nums.size, f_num)
+                f_mean = np.array(f_mean).reshape(sample_nums.size, f_num)
+
+                return f_mat, f_min, f_max, f_mean
+
+
+            class MeanNormDemo:
+
+                def mean_normalize(self, **kwargs):
+                    """
+                    mean_normalize op demo
+                    :param kwargs:
+                    :return:
+                    """
+                    role = kwargs['role']
+
+                    pfl_mpc.init("aby3", role, "localhost", redis_server, redis_port)
+
+                    mi = pfl_mpc.data(name='mi', shape=self.input_size, dtype='int64')
+                    ma = pfl_mpc.data(name='ma', shape=self.input_size, dtype='int64')
+                    me = pfl_mpc.data(name='me', shape=self.input_size, dtype='int64')
+                    sn = pfl_mpc.data(name='sn', shape=self.input_size, dtype='int64')
+
+                    out0, out1 = pfl_mpc.layers.mean_normalize(f_min=mi, f_max=ma,
+                        f_mean=me, sample_num=sn, total_sample_num=self.total_num)
+
+                    exe = fluid.Executor(place=fluid.CPUPlace())
+
+                    f_range, f_mean = exe.run(feed={'mi': kwargs['min'],
+                        'ma': kwargs['max'], 'me': kwargs['mean'],
+                        'sn': kwargs['sample_num']}, fetch_list=[out0, out1])
+
+                    self.f_range_list.append(f_range)
+                    self.f_mean_list.append(f_mean)
+
+                def run(self):
+                    f_nums = test_f_num
+                    sample_nums = np.array(test_row_split)
+                    mat, mi, ma, me = gen_data(f_nums, sample_nums)
+
+                    self.input_size = [len(sample_nums), f_nums]
+                    self.total_num = mat.shape[0]
+
+                    # simulate encrypting data
+                    share = lambda x: np.array([x * mdu.mpc_one_share] * 2).astype('int64').reshape(
+                        [2] + list(x.shape))
+
+                    self.f_range_list = Manager().list()
+                    self.f_mean_list = Manager().list()
+
+                    proc = list()
+                    for role in range(3):
+                        args = {'role': role, 'min': share(mi), 'max': share(ma),
+                                'mean': share(me), 'sample_num': share(sample_nums)}
+                        p = Process(target=self.mean_normalize, kwargs=args)
+
+                        proc.append(p)
+                        p.start()
+
+                    for p in proc:
+                        p.join()
+
+                    f_r = aby3.reconstruct(np.array(self.f_range_list))
+                    f_m = aby3.reconstruct(np.array(self.f_mean_list))
+
+                    plain_r, plain_m = mean_norm_naive(mat)
+                    print("max error in feature range:", np.max(np.abs(f_r - plain_r)))
+                    print("max error in feature mean:", np.max(np.abs(f_m - plain_m)))
+
+
+            MeanNormDemo().run()
+    '''
+    helper = MpcLayerHelper("mean_normalize", **locals())
+
+    # dtype = helper.input_dtype()
+    dtype = 'int64'
+
+    check_dtype(dtype, 'f_min', ['int64'], 'mean_normalize')
+    check_dtype(dtype, 'f_max', ['int64'], 'mean_normalize')
+    check_dtype(dtype, 'f_mean', ['int64'], 'mean_normalize')
+    check_dtype(dtype, 'sample_num', ['int64'], 'mean_normalize')
+
+    f_range = helper.create_mpc_variable_for_type_inference(dtype=f_min.dtype)
+    f_mean_out = helper.create_mpc_variable_for_type_inference(dtype=f_min.dtype)
+
+    op_type = 'mean_normalize'
+
+    helper.append_op(
+        type='mpc_' + op_type,
+        inputs={
+            "Min": f_min,
+            "Max": f_max,
+            "Mean": f_mean,
+            "SampleNum": sample_num,
+        },
+        outputs={
+            "Range": f_range,
+            "MeanOut": f_mean_out,
+        },
+        attrs={
+            # TODO: remove attr total_sample_num, reduce sample_num instead
+            "total_sample_num": total_sample_num,
+        })
+
+    return f_range, f_mean_out
diff --git a/python/paddle_fl/mpc/tests/unittests/run_test_example.sh b/python/paddle_fl/mpc/tests/unittests/run_test_example.sh
index 30ede58dfbbc401651ae6d832214e6b2d70f785c..7ee386f20a36292c05c6b2e3b1cd492ee7751f12 100644
--- a/python/paddle_fl/mpc/tests/unittests/run_test_example.sh
+++ b/python/paddle_fl/mpc/tests/unittests/run_test_example.sh
@@ -26,6 +26,7 @@ TEST_MODULES=("test_datautils_aby3"
 "test_op_conv"
 "test_op_pool"
 "test_op_metric"
+"test_data_preprocessing"
 )
 
 # run unittest
diff --git a/python/paddle_fl/mpc/tests/unittests/test_data_preprocessing.py b/python/paddle_fl/mpc/tests/unittests/test_data_preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec613c1d3902e096ec6e0281d4a94e66cb71c7bb
--- /dev/null
+++ b/python/paddle_fl/mpc/tests/unittests/test_data_preprocessing.py
@@ -0,0 +1,124 @@
+
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This module tests data preprocessing.
+
+"""
+import unittest
+from multiprocessing import Manager
+
+import numpy as np
+import paddle.fluid as fluid
+import paddle_fl.mpc as pfl_mpc
+import mpc_data_utils as mdu
+import paddle_fl.mpc.data_utils.aby3 as aby3
+
+import test_op_base
+
+
+def mean_norm_naive(f_mat):
+    ma = np.amax(f_mat, axis=0)
+    mi = np.amin(f_mat, axis=0)
+
+    return ma - mi, np.mean(f_mat, axis=0)
+
+
+def gen_data(f_num, sample_nums):
+    f_mat = np.random.rand(np.sum(sample_nums), f_num)
+
+    f_min, f_max, f_mean = [], [], []
+
+    prev_idx = 0
+
+    for n in sample_nums:
+        i = prev_idx
+        j = i + n
+
+        ma = np.amax(f_mat[i:j], axis=0)
+        mi = np.amin(f_mat[i:j], axis=0)
+        me = np.mean(f_mat[i:j], axis=0)
+
+        f_min.append(mi)
+        f_max.append(ma)
+        f_mean.append(me)
+
+        prev_idx += n
+
+    f_min = np.array(f_min).reshape(sample_nums.size, f_num)
+    f_max = np.array(f_max).reshape(sample_nums.size, f_num)
+    f_mean = np.array(f_mean).reshape(sample_nums.size, f_num)
+
+    return f_mat, f_min, f_max, f_mean
+
+
+class TestOpMeanNormalize(test_op_base.TestOpBase):
+
+    def mean_normalize(self, **kwargs):
+        """
+        mean_normalize op ut
+        :param kwargs:
+        :return:
+        """
+        role = kwargs['role']
+
+        pfl_mpc.init("aby3", role, "localhost", self.server, int(self.port))
+
+        mi = pfl_mpc.data(name='mi', shape=self.input_size, dtype='int64')
+        ma = pfl_mpc.data(name='ma', shape=self.input_size, dtype='int64')
+        me = pfl_mpc.data(name='me', shape=self.input_size, dtype='int64')
+        sn = pfl_mpc.data(name='sn', shape=self.input_size, dtype='int64')
+
+        out0, out1 = pfl_mpc.layers.mean_normalize(f_min=mi, f_max=ma,
+            f_mean=me, sample_num=sn, total_sample_num=self.total_num)
+
+        exe = fluid.Executor(place=fluid.CPUPlace())
+
+        f_range, f_mean = exe.run(feed={'mi': kwargs['min'],
+            'ma': kwargs['max'], 'me': kwargs['mean'],
+            'sn': kwargs['sample_num']}, fetch_list=[out0, out1])
+
+        self.f_range_list.append(f_range)
+        self.f_mean_list.append(f_mean)
+
+    def test_mean_normalize(self):
+
+        f_nums = 100
+        sample_nums = np.array(range(2, 10, 2))
+        mat, mi, ma, me = gen_data(f_nums, sample_nums)
+
+        self.input_size = [len(sample_nums), f_nums]
+        self.total_num = mat.shape[0]
+
+        # simulate encrypting data: each party holds two local shares
+        share = lambda x: np.array([x * mdu.mpc_one_share] * 2).astype('int64').reshape(
+            [2] + list(x.shape))
+
+        self.f_range_list = Manager().list()
+        self.f_mean_list = Manager().list()
+
+        ret = self.multi_party_run(target=self.mean_normalize,
+                                   min=share(mi), max=share(ma), mean=share(me),
+                                   sample_num=share(sample_nums))
+
+        self.assertEqual(ret[0], True)
+
+        f_r = aby3.reconstruct(np.array(self.f_range_list))
+        f_m = aby3.reconstruct(np.array(self.f_mean_list))
+
+        plain_r, plain_m = mean_norm_naive(mat)
+        self.assertTrue(np.allclose(f_r, plain_r, atol=1e-4))
+        self.assertTrue(np.allclose(f_m, plain_m, atol=1e-3))
+
+
+if __name__ == '__main__':
+    unittest.main()
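Why the kernel's matmul-plus-scale recovers the global mean, which is what the test's comparison against `mean_norm_naive` exercises: weighting each party's local mean by its sample count and dividing by the total count is exactly the mean of the concatenated data. A quick plaintext check of that identity (numpy only, no MPC; sizes mirror `gen_data`):

```python
import numpy as np

sample_nums = np.array([2, 4, 6, 8])
f_mat = np.random.rand(sample_nums.sum(), 100)

# per-party local means, as each party would report them
bounds = np.cumsum(np.insert(sample_nums, 0, 0))
local_means = np.array([f_mat[i:j].mean(axis=0)
                        for i, j in zip(bounds[:-1], bounds[1:])])

# weighted recombination equals the global mean
recombined = sample_nums @ local_means / sample_nums.sum()
assert np.allclose(recombined, f_mat.mean(axis=0))
```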