From 63ae7e624c9585f915112136a723b164a513fd07 Mon Sep 17 00:00:00 2001 From: "He, Kai" Date: Wed, 16 Sep 2020 08:52:05 +0000 Subject: [PATCH] add mpc operator add, move mean_normalize to ml.py --- .../mpc_protocol/aby3_operators.h | 21 ++-- .../paddlefl_mpc/mpc_protocol/mpc_operators.h | 4 + .../operators/mpc_mean_normalize_op.h | 4 +- python/paddle_fl/mpc/layers/__init__.py | 3 - .../mpc/layers/data_preprocessing.py | 107 ------------------ python/paddle_fl/mpc/layers/ml.py | 98 +++++++++++++++- 6 files changed, 114 insertions(+), 123 deletions(-) delete mode 100644 python/paddle_fl/mpc/layers/data_preprocessing.py diff --git a/core/paddlefl_mpc/mpc_protocol/aby3_operators.h b/core/paddlefl_mpc/mpc_protocol/aby3_operators.h index 70e4543..411233d 100644 --- a/core/paddlefl_mpc/mpc_protocol/aby3_operators.h +++ b/core/paddlefl_mpc/mpc_protocol/aby3_operators.h @@ -319,17 +319,24 @@ public: auto a_tuple = from_tensor(in); auto a_ = std::get<0>(a_tuple).get(); + auto b_tuple = from_tensor(pos_info); + auto b_ = std::get<0>(b_tuple).get(); + auto out_tuple = from_tensor(out); auto out_ = std::get<0>(out_tuple).get(); - if (pos_info) { - auto b_tuple = from_tensor(pos_info); - auto b_ = std::get<0>(b_tuple).get(); + a_->max_pooling(out_, b_); + } + + void max(const Tensor* in, Tensor* out) override { + + auto a_tuple = from_tensor(in); + auto a_ = std::get<0>(a_tuple).get(); + + auto out_tuple = from_tensor(out); + auto out_ = std::get<0>(out_tuple).get(); - a_->max_pooling(out_, b_); - } else { - a_->max_pooling(out_, nullptr); - } + a_->max_pooling(out_, nullptr); } void inverse_square_root(const Tensor* in, Tensor* out) override { diff --git a/core/paddlefl_mpc/mpc_protocol/mpc_operators.h b/core/paddlefl_mpc/mpc_protocol/mpc_operators.h index 33f2f17..6b69cd1 100644 --- a/core/paddlefl_mpc/mpc_protocol/mpc_operators.h +++ b/core/paddlefl_mpc/mpc_protocol/mpc_operators.h @@ -82,6 +82,10 @@ public: // for filter in other shape, reshape input first virtual void 
max_pooling(const Tensor* in, Tensor* out, Tensor* pos_info) {} + // column wise max + // in shape [n, ...], out shape [1, ...] + virtual void max(const Tensor* in, Tensor* out) {} + virtual void inverse_square_root(const Tensor* in, Tensor* out) = 0; virtual void predicts_to_indices(const Tensor* in, diff --git a/core/paddlefl_mpc/operators/mpc_mean_normalize_op.h b/core/paddlefl_mpc/operators/mpc_mean_normalize_op.h index dcf9bba..2bd5b85 100644 --- a/core/paddlefl_mpc/operators/mpc_mean_normalize_op.h +++ b/core/paddlefl_mpc/operators/mpc_mean_normalize_op.h @@ -54,10 +54,10 @@ class MpcMeanNormalizationKernel : public MpcOpKernel { ->mpc_operators()->neg(min, &neg_min); mpc::MpcInstance::mpc_instance()->mpc_protocol() - ->mpc_operators()->max_pooling(&neg_min, &neg_min_global, nullptr); + ->mpc_operators()->max(&neg_min, &neg_min_global); mpc::MpcInstance::mpc_instance()->mpc_protocol() - ->mpc_operators()->max_pooling(max, &max_global, nullptr); + ->mpc_operators()->max(max, &max_global); range->mutable_data( framework::make_ddim({share_num, 1, feat_num}), context.GetPlace(), 0); diff --git a/python/paddle_fl/mpc/layers/__init__.py b/python/paddle_fl/mpc/layers/__init__.py index aebd09d..3f6b0f2 100644 --- a/python/paddle_fl/mpc/layers/__init__.py +++ b/python/paddle_fl/mpc/layers/__init__.py @@ -37,8 +37,6 @@ from . import rnn from .rnn import * from . import metric_op from .metric_op import * -from . import data_preprocessing -from .data_preprocessing import * __all__ = [] __all__ += basic.__all__ @@ -48,4 +46,3 @@ __all__ += ml.__all__ __all__ += compare.__all__ __all__ += conv.__all__ __all__ += metric_op.__all__ -__all__ += data_preprocessing.__all__ diff --git a/python/paddle_fl/mpc/layers/data_preprocessing.py b/python/paddle_fl/mpc/layers/data_preprocessing.py deleted file mode 100644 index 813f85f..0000000 --- a/python/paddle_fl/mpc/layers/data_preprocessing.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. 
All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -mpc data preprocessing op layers. -""" -from paddle.fluid.data_feeder import check_type, check_dtype -from ..framework import check_mpc_variable_and_dtype -from ..mpc_layer_helper import MpcLayerHelper -from .math import reduce_sum - -__all__ = ['mean_normalize'] - -def mean_normalize(f_min, f_max, f_mean, sample_num): - ''' - Mean normalization is a method used to normalize the range of independent - variables or features of data. - Refer to: - https://en.wikipedia.org/wiki/Feature_scaling#Mean_normalization - - Args: - f_min (Variable): A 2-D tensor with shape [P, N], where P is the party - num and N is the feature num. Each row contains the - local min feature val of N features. - f_max (Variable): A 2-D tensor with shape [P, N], where P is the party - num and N is the feature num. Each row contains the - local max feature val of N features. - f_mean (Variable): A 2-D tensor with shape [P, N], where P is the party - num and N is the feature num. Each row contains the - local min feature val of N features. - sample_num (Variable): A 1-D tensor with shape [P], where P is the - party num. Each element contains sample num - of party_i. - - Returns: - f_range (Variable): A 1-D tensor with shape [N], where N is the - feature num. Each element contains global - range of feature_i. - f_mean_out (Variable): A 1-D tensor with shape [N], where N is the - feature num. 
Each element contains global - range of feature_i. - Examples: - .. code-block:: python - import paddle_fl.mpc as pfl_mpc - - pfl_mpc.init("aby3", role, "localhost", redis_server, redis_port) - - # 2 for share, 4 for 4 party, 100 for feat_num - input_size = [2, 4, 100] - - mi = pfl_mpc.data(name='mi', shape=input_size, dtype='int64') - ma = pfl_mpc.data(name='ma', shape=input_size, dtype='int64') - me = pfl_mpc.data(name='me', shape=input_size, dtype='int64') - sn = pfl_mpc.data(name='sn', shape=input_size[:-1], dtype='int64') - - out0, out1 = pfl_mpc.layers.mean_normalize(f_min=mi, f_max=ma, - f_mean=me, sample_num=sn) - - exe = fluid.Executor(place=fluid.CPUPlace()) - - # feed encrypted data - f_range, f_mean = exe.run(feed={'mi': f_min, 'ma': f_max, - 'me': f_mean, 'sn': sample_num}, fetch_list=[out0, out1]) - ''' - helper = MpcLayerHelper("mean_normalize", **locals()) - - # dtype = helper.input_dtype() - dtype = 'int64' - - check_dtype(dtype, 'f_min', ['int64'], 'mean_normalize') - check_dtype(dtype, 'f_max', ['int64'], 'mean_normalize') - check_dtype(dtype, 'f_mean', ['int64'], 'mean_normalize') - check_dtype(dtype, 'sample_num', ['int64'], 'mean_normalize') - - f_range = helper.create_mpc_variable_for_type_inference(dtype=f_min.dtype) - f_mean_out= helper.create_mpc_variable_for_type_inference(dtype=f_min.dtype) - - total_num = reduce_sum(sample_num) - - op_type = 'mean_normalize' - - helper.append_op( - type='mpc_' + op_type, - inputs={ - "Min": f_min, - "Max": f_max, - "Mean": f_mean, - "SampleNum": sample_num, - "TotalNum": total_num, - }, - outputs={ - "Range": f_range, - "MeanOut": f_mean_out, - }, - ) - - return f_range, f_mean_out diff --git a/python/paddle_fl/mpc/layers/ml.py b/python/paddle_fl/mpc/layers/ml.py index c45953a..2085c94 100644 --- a/python/paddle_fl/mpc/layers/ml.py +++ b/python/paddle_fl/mpc/layers/ml.py @@ -37,6 +37,7 @@ __all__ = [ 'pool2d', 'batch_norm', 'reshape', + 'mean_normalize', ] @@ -612,7 +613,7 @@ def reshape(x, shape, 
actual_shape=None, act=None, inplace=False, name=None): helper = MpcLayerHelper("reshape2", **locals()) _helper = LayerHelper("reshape2", **locals()) - + def get_new_shape_tensor(list_shape): new_shape_tensor = [] for dim in list_shape: @@ -625,7 +626,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out) new_shape_tensor.append(temp_out) return new_shape_tensor - + def get_attr_shape(list_shape): unk_dim_idx = -1 attrs_shape = [] @@ -662,13 +663,13 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): assert len(shape) > 0, ("The size of 'shape' in reshape can't be zero, " "but received %s." % len(shape)) attrs["shape"] = get_attr_shape(shape) - + if utils._contain_var(shape): inputs['ShapeTensor'] = get_new_shape_tensor(shape) elif isinstance(actual_shape, Variable): actual_shape.stop_gradient = True inputs["Shape"] = actual_shape - + out = x if inplace else helper.create_mpc_variable_for_type_inference( dtype=x.dtype) x_shape = helper.create_mpc_variable_for_type_inference(dtype=x.dtype) @@ -680,3 +681,92 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None): "XShape": x_shape}) return helper.append_activation(out) + + +def mean_normalize(f_min, f_max, f_mean, sample_num): + ''' + Mean normalization is a method used to normalize the range of independent + variables or features of data. + Refer to: + https://en.wikipedia.org/wiki/Feature_scaling#Mean_normalization + + Args: + f_min (Variable): A 2-D tensor with shape [P, N], where P is the party + num and N is the feature num. Each row contains the + local min feature val of N features. + f_max (Variable): A 2-D tensor with shape [P, N], where P is the party + num and N is the feature num. Each row contains the + local max feature val of N features. + f_mean (Variable): A 2-D tensor with shape [P, N], where P is the party + num and N is the feature num. 
Each row contains the + local mean feature val of N features. + sample_num (Variable): A 1-D tensor with shape [P], where P is the + party num. Each element contains sample num + of party_i. + + Returns: + f_range (Variable): A 1-D tensor with shape [N], where N is the + feature num. Each element contains global + range of feature_i. + f_mean_out (Variable): A 1-D tensor with shape [N], where N is the + feature num. Each element contains global + mean of feature_i. + Examples: + .. code-block:: python + import paddle_fl.mpc as pfl_mpc + + pfl_mpc.init("aby3", role, "localhost", redis_server, redis_port) + + # 2 for share, 4 for 4 party, 100 for feat_num + input_size = [2, 4, 100] + + mi = pfl_mpc.data(name='mi', shape=input_size, dtype='int64') + ma = pfl_mpc.data(name='ma', shape=input_size, dtype='int64') + me = pfl_mpc.data(name='me', shape=input_size, dtype='int64') + sn = pfl_mpc.data(name='sn', shape=input_size[:-1], dtype='int64') + + out0, out1 = pfl_mpc.layers.mean_normalize(f_min=mi, f_max=ma, + f_mean=me, sample_num=sn) + + exe = fluid.Executor(place=fluid.CPUPlace()) + + # feed encrypted data + f_range, f_mean = exe.run(feed={'mi': f_min, 'ma': f_max, + 'me': f_mean, 'sn': sample_num}, fetch_list=[out0, out1]) + ''' + helper = MpcLayerHelper("mean_normalize", **locals()) + + # dtype = helper.input_dtype() + dtype = 'int64' + + check_dtype(dtype, 'f_min', ['int64'], 'mean_normalize') + check_dtype(dtype, 'f_max', ['int64'], 'mean_normalize') + check_dtype(dtype, 'f_mean', ['int64'], 'mean_normalize') + check_dtype(dtype, 'sample_num', ['int64'], 'mean_normalize') + + f_range = helper.create_mpc_variable_for_type_inference(dtype=f_min.dtype) + f_mean_out= helper.create_mpc_variable_for_type_inference(dtype=f_min.dtype) + + # to avoid circular dependencies + from .math import reduce_sum + + total_num = reduce_sum(sample_num) + + op_type = 'mean_normalize' + + helper.append_op( + type='mpc_' + op_type, + inputs={ + "Min": f_min, + "Max": f_max, + "Mean": 
f_mean, + "SampleNum": sample_num, + "TotalNum": total_num, + }, + outputs={ + "Range": f_range, + "MeanOut": f_mean_out, + }, + ) + + return f_range, f_mean_out -- GitLab