未验证 提交 3105a2b1 编写于 作者: J jed 提交者: GitHub

Merge pull request #122 from kaih70/master

add mean normalize
......@@ -328,6 +328,17 @@ public:
a_->max_pooling(out_, b_);
}
void max(const Tensor* in, Tensor* out) override {
auto a_tuple = from_tensor(in);
auto a_ = std::get<0>(a_tuple).get();
auto out_tuple = from_tensor(out);
auto out_ = std::get<0>(out_tuple).get();
a_->max_pooling(out_, nullptr);
}
void inverse_square_root(const Tensor* in, Tensor* out) override {
auto x_tuple = from_tensor(in);
auto x_ = std::get<0>(x_tuple).get();
......@@ -377,6 +388,20 @@ public:
FixedTensor::calc_precision_recall(in, &out_);
}
void div(const Tensor *lhs, const Tensor *rhs, Tensor *out) override {
auto lhs_tuple = from_tensor(lhs);
auto rhs_tuple = from_tensor(rhs);
auto out_tuple = from_tensor(out);
auto lhs_ = std::get<0>(lhs_tuple).get();
auto rhs_ = std::get<0>(rhs_tuple).get();
auto out_ = std::get<0>(out_tuple).get();
lhs_->long_div(rhs_, out_);
}
private:
template <typename T>
std::tuple<
......
......@@ -82,6 +82,10 @@ public:
// for filter in other shape, reshape input first
virtual void max_pooling(const Tensor* in, Tensor* out, Tensor* pos_info) {}
// column wise max
// in shape [n, ...], out shape [1, ...]
virtual void max(const Tensor* in, Tensor* out) {}
virtual void inverse_square_root(const Tensor* in, Tensor* out) = 0;
virtual void predicts_to_indices(const Tensor* in,
......@@ -93,6 +97,8 @@ public:
Tensor* out) = 0;
virtual void calc_precision_recall(const Tensor* tp_fp_fn, Tensor* out) = 0;
virtual void div(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0;
};
} // mpc
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mpc_mean_normalize_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include <string>
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class MpcMeanNormalizationOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Min"), true,
platform::errors::InvalidArgument(
"Input(Min) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Max"), true,
platform::errors::InvalidArgument("Input(Max) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Mean"), true,
platform::errors::InvalidArgument("Input(Mean) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("SampleNum"), true,
platform::errors::InvalidArgument("Input(Sample) should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("TotalNum"), true,
platform::errors::InvalidArgument("Input(TotalNum) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Range"), true,
platform::errors::InvalidArgument(
"Output(Range) should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("MeanOut"), true,
platform::errors::InvalidArgument(
"Output(Meanor) should not be null."));
auto min_dims = ctx->GetInputDim("Min");
auto max_dims = ctx->GetInputDim("Max");
auto mean_dims = ctx->GetInputDim("Mean");
auto sample_num_dims = ctx->GetInputDim("SampleNum");
auto total_num_dims = ctx->GetInputDim("TotalNum");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(min_dims, max_dims,
platform::errors::InvalidArgument(
"The dimension of Input(Min) and "
"Input(Max) should be the same."
"But received (%d) != (%d)",
min_dims, max_dims));
PADDLE_ENFORCE_EQ(min_dims, mean_dims,
platform::errors::InvalidArgument(
"The dimension of Input(Min) and "
"Input(Max) should be the same."
"But received (%d) != (%d)",
min_dims, mean_dims));
PADDLE_ENFORCE_EQ(
min_dims.size(), 3,
platform::errors::InvalidArgument(
"The dimension of Input(Min) should be equal to 3 "
"(share_num, party_num, feature_num). But received (%d)",
min_dims.size()));
PADDLE_ENFORCE_EQ(
sample_num_dims.size(), 2,
platform::errors::InvalidArgument(
"The dimension of Input(SampleNum) should be equal to 2 "
"(share_num, party_num). But received (%d)",
sample_num_dims.size()));
PADDLE_ENFORCE_EQ(
sample_num_dims[1], min_dims[1],
platform::errors::InvalidArgument(
"The party num of Input(SampleNum) and Input(Min) "
"should be equal But received (%d) != (%d)",
sample_num_dims[1], min_dims[1]));
PADDLE_ENFORCE_EQ(
total_num_dims.size(), 2,
platform::errors::InvalidArgument(
"The dimension of Input(TotalNum) "
"should be 2, But received (%d) != (%d)",
total_num_dims.size(), 2));
PADDLE_ENFORCE_EQ(
sample_num_dims[0], total_num_dims[0],
platform::errors::InvalidArgument(
"The share num of Input(SampleNum) and Input(TotalNum) "
"should be equal But received (%d) != (%d)",
sample_num_dims[0], total_num_dims[0]));
PADDLE_ENFORCE_EQ(
total_num_dims[1], 1,
platform::errors::InvalidArgument(
"The shape of Input(TotalNum) "
"should be [share_num, 1] But dims[1] received (%d) != (%d)",
total_num_dims[1], 1));
}
ctx->SetOutputDim("Range", {mean_dims[0], mean_dims[2]});
ctx->SetOutputDim("MeanOut", {mean_dims[0], mean_dims[2]});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Min"),
ctx.device_context());
}
};
class MpcMeanNormalizationOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Min",
"(Tensor, default Tensor<int64_t>) A 2-D tensor with shape [P, N], "
"where P is the party num and N is the feature num. Each row contains "
" the local min feature val of N features.");
AddInput("Max",
"(Tensor, default Tensor<int64_t>) A 2-D tensor with shape [P, N], "
"where P is the party num and N is the feature num. Each row contains "
" the local max feature val of N features.");
AddInput("Mean",
"(Tensor, default Tensor<int64_t>) A 2-D tensor with shape [P, N], "
"where P is the party num and N is the feature num. Each row contains "
" the local mean feature val of N features.");
AddInput("SampleNum",
"(Tensor, default Tensor<int64_t>) A 1-D tensor with shape [P], "
"where P is the party num. Each element contains "
"sample num of party_i.");
AddInput("TotalNum",
"(Tensor, default Tensor<int64_t>) A 1-D tensor with shape [1], "
"Element contains sum of sample num of party_i.");
AddOutput("Range",
"(Tensor, default Tensor<int64_t>) A 1-D tensor with shape [N], "
"where N is the feature num. Each element contains "
"global range of feature_i.");
AddOutput("MeanOut",
"(Tensor, default Tensor<int64_t>) A 1-D tensor with shape [N], "
"where N is the feature num. Each element contains "
"global mean of feature_i.");
AddComment(R"DOC(
Mean normalization Operator.
When given Input(Min), Input(Max), Input(Mean), Input(SampleNum) and Input(TotalNum)
this operator can be used to compute global range and mean for further feature
scaling.
Output(Range) is the global range of all features.
Output(MeanOut) is the global mean of all features.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
mpc_mean_normalize, ops::MpcMeanNormalizationOp, ops::MpcMeanNormalizationOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
mpc_mean_normalize,
ops::MpcMeanNormalizationKernel<paddle::platform::CPUPlace, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class MpcMeanNormalizationKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext& context) const override {
const Tensor* min = context.Input<Tensor>("Min");
const Tensor* max = context.Input<Tensor>("Max");
const Tensor* mean = context.Input<Tensor>("Mean");
const Tensor* sample_num = context.Input<Tensor>("SampleNum");
const Tensor* total_num = context.Input<Tensor>("TotalNum");
Tensor* range = context.Output<Tensor>("Range");
Tensor* mean_out = context.Output<Tensor>("MeanOut");
int share_num = min->dims()[0];
int party_num = min->dims()[1];
int feat_num = min->dims()[2];
Tensor neg_min;
neg_min.mutable_data<T>(min->dims(), context.GetPlace(), 0);
Tensor neg_min_global;
Tensor max_global;
neg_min_global.mutable_data<T>(
framework::make_ddim({share_num, 1, feat_num}), context.GetPlace(), 0);
max_global.mutable_data<T>(
framework::make_ddim({share_num, 1, feat_num}), context.GetPlace(), 0);
mpc::MpcInstance::mpc_instance()->mpc_protocol()
->mpc_operators()->neg(min, &neg_min);
mpc::MpcInstance::mpc_instance()->mpc_protocol()
->mpc_operators()->max(&neg_min, &neg_min_global);
mpc::MpcInstance::mpc_instance()->mpc_protocol()
->mpc_operators()->max(max, &max_global);
range->mutable_data<T>(
framework::make_ddim({share_num, 1, feat_num}), context.GetPlace(), 0);
mpc::MpcInstance::mpc_instance()->mpc_protocol()
->mpc_operators()->add(&max_global, &neg_min_global, range);
range->mutable_data<T>(
framework::make_ddim({share_num, feat_num}), context.GetPlace(), 0);
Tensor sample_num_;
sample_num_.ShareDataWith(*sample_num);
sample_num_.mutable_data<T>(
framework::make_ddim({share_num, 1, party_num}), context.GetPlace(), 0);
mean_out->mutable_data<T>(
framework::make_ddim({share_num, 1, feat_num}), context.GetPlace(), 0);
mpc::MpcInstance::mpc_instance()->mpc_protocol()
->mpc_operators()->matmul(&sample_num_, mean, mean_out);
mean_out->mutable_data<T>(
framework::make_ddim({share_num, feat_num}), context.GetPlace(), 0);
Tensor total_num_;
total_num_.mutable_data<T>(
framework::make_ddim({share_num, feat_num}), context.GetPlace(), 0);
// broadcasting total_num to shape [share_num, feat_num]
for (int i = 0; i < share_num; ++i) {
std::fill(total_num_.data<T>() + i * feat_num,
total_num_.data<T>() + (i + 1) * feat_num,
total_num->data<T>()[i]);
}
mpc::MpcInstance::mpc_instance()->mpc_protocol()
->mpc_operators()->div(mean_out, &total_num_, mean_out);
}
};
} // namespace operators
} // namespace paddle
## Instructions for PaddleFL-MPC Mean Normalize Demo
This document introduces how to run Mean Normalize demo based on Paddle-MPC,
which is single machine demo.
### Running on Single Machine
#### (1). Prepare Data
Create a empty dir for data, and modify `data_path` in `process_data.py`,
default dir path is `./data`.
Then run the script with command `python prepare.py` to generate random data
for demo, which is dumped by numpy and named `feature_data.{i}.npy` located
in `data_path`. Otherwise generate your own data, move them to `data_path`,
name as the same way, and modify corresponding meta info in `prepare.py`.
Encrypted data files of feature statstics would be generated and saved in
`data_path` directory. Different suffix names are used for these files to
indicate the ownership of different data source and computation parties.
For instance, a file named `feature_max.1.part2` means it contains the max
feature values from data owner 1 and needs to be feed to computing party 2.
#### (2). Launch Demo with A Shell Script
You should set the env params as follow:
```
export PYTHON=/yor/python
export PATH_TO_REDIS_BIN=/path/to/redis_bin
export LOCALHOST=/your/localhost
export REDIS_PORT=/your/redis/port
```
Launch demo with the `run_standalone.sh` script. The concrete command is:
```bash
bash ../run_standalone.sh mean_normalize_demo.py
```
The ciphertext result of global feature range and feature mean will be save in
`data_path` directory, named `result.part{i}`.
#### (3). Decrypt Data
Finally, using `decrypt_data()` in `process_data.py` script, this demo would
decrypt and returns the result, which can be used to rescale local feature data
by all data owners respectively.
```python
import prepare
import process_data
# 0 for f_range, 1 for f_mean
# use decrypted global f_range and f_mean to rescaling local feature data
res = process_data.decrypt_data(prepare.data_path + 'result', (2, prepare.feat_width, ))
```
Or use `decrypt_and_rescale.py` to decrypt, rescale the feature data which has
been saved in `feature_data.{i}.npy`, and dump the normalized data to
`normalized_data.{i}.npy` which is located in `data_path`.
Also, `verify.py` could be used to calculate error of `f_range` and `f_mean`
between direct plaintext numpy calculation and mpc mean normalize.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Decrypt and rescale for mean normalize demo.
"""
import sys
import numpy as np
import process_data
import prepare
data_path = prepare.data_path
# 0 for f_range, 1 for f_mean
# use decrypted global f_range and f_mean to rescaling local feature data
res = process_data.decrypt_data(data_path + 'result', (2, prepare.feat_width, ))
party = sys.argv[1]
input = np.load(data_path + 'feature_data.' + party + '.npy')
output = (input - res[1]) / res[0]
np.save(data_path + 'normalized_data.' + party, output)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mean normalize demo.
"""
import sys
import numpy as np
import paddle.fluid as fluid
import paddle_fl.mpc as pfl_mpc
import paddle_fl.mpc.data_utils.aby3 as aby3
import prepare
import process_data
role, server, port = sys.argv[1], sys.argv[2], sys.argv[3]
role, port = int(role), int(port)
share_num = aby3.ABY3_SHARE_DIM
party_num = len(prepare.sample_nums)
feat_num = prepare.feat_width
data_path = prepare.data_path
def get_shares(path):
'''
collect encrypted feature stats from all data owners
'''
data = []
for i in range(party_num):
reader = aby3.load_aby3_shares(path + '.' + str(i),
id=role, shape=(feat_num,))
data.append([x for x in reader()])
data = np.array(data).reshape([party_num, share_num, feat_num])
return np.transpose(data, axes=[1, 0, 2])
def get_sample_num(path):
'''
get encrypted sample nums
'''
reader = aby3.load_aby3_shares(path,
id=role, shape=(party_num,))
for n in reader():
return n
f_max = get_shares(data_path + 'feature_max')
f_min = get_shares(data_path + 'feature_min')
f_mean = get_shares(data_path + 'feature_mean')
sample_num = get_sample_num(data_path + 'sample_num')
pfl_mpc.init("aby3", int(role), "localhost", server, int(port))
shape = [party_num, feat_num]
mi = pfl_mpc.data(name='mi', shape=shape, dtype='int64')
ma = pfl_mpc.data(name='ma', shape=shape, dtype='int64')
me = pfl_mpc.data(name='me', shape=shape, dtype='int64')
sn = pfl_mpc.data(name='sn', shape=shape[:-1], dtype='int64')
out0, out1 = pfl_mpc.layers.mean_normalize(f_min=mi, f_max=ma,
f_mean=me, sample_num=sn)
exe = fluid.Executor(place=fluid.CPUPlace())
f_range, f_mean = exe.run(feed={'mi': f_min, 'ma': f_max, 'me': f_mean,
'sn': sample_num},fetch_list=[out0, out1])
result = np.transpose(np.array([f_range, f_mean]), axes=[1, 0, 2])
result_file = data_path + "result.part{}".format(role)
with open(result_file, 'wb') as f:
f.write(result.tostring())
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Prepare data for mean normalize demo.
"""
import numpy as np
import process_data
from paddle_fl.mpc.data_utils import aby3
data_path = process_data.data_path
feat_width = 100
# assume data owner i has sample_nums[i] samples
sample_nums = [1, 2, 3, 4]
def gen_random_data():
for i, num in enumerate(sample_nums):
suffix = '.' + str(i)
f_mat = np.random.rand(num, feat_width)
np.save(data_path + 'feature_data' + suffix, f_mat)
process_data.generate_encrypted_data(i, f_mat)
aby3.save_aby3_shares(process_data.encrypted_data(np.array(sample_nums)),
data_path + 'sample_num')
if __name__ == "__main__":
gen_random_data()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Process data for mean normalize demo.
"""
import numpy as np
import six
import os
import paddle
from paddle_fl.mpc.data_utils import aby3
data_path = './data/'
def encrypted_data(data):
"""
feature stat reader
"""
def func():
yield aby3.make_shares(data)
return func
def generate_encrypted_data(party_id, f_mat):
"""
generate encrypted data from feature matrix (np.array)
"""
f_max = np.amax(f_mat, axis=0)
f_min = np.amin(f_mat, axis=0)
f_mean = np.mean(f_mat, axis=0)
suffix = '.' + str(party_id)
aby3.save_aby3_shares(encrypted_data(f_max),
data_path + "feature_max" + suffix)
aby3.save_aby3_shares(encrypted_data(f_min),
data_path + "feature_min" + suffix)
aby3.save_aby3_shares(encrypted_data(f_mean),
data_path + "feature_mean" + suffix)
def decrypt_data(filepath, shape):
"""
load the encrypted data and reconstruct
"""
part_readers = []
for id in six.moves.range(3):
part_readers.append(
aby3.load_aby3_shares(
filepath, id=id, shape=shape))
aby3_share_reader = paddle.reader.compose(part_readers[0], part_readers[1],
part_readers[2])
for instance in aby3_share_reader():
p = aby3.reconstruct(np.array(instance))
return p
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Verification for mean normalize demo.
"""
import prepare
import process_data
import numpy as np
import paddle_fl.mpc.data_utils.aby3 as aby3
# 0 for f_range, 1 for f_mean
# use decrypted global f_range and f_mean to rescaling local feature data
res = process_data.decrypt_data(prepare.data_path + 'result', (2, prepare.feat_width, ))
# reconstruct plaintext global data to verify
row, col = sum(prepare.sample_nums), prepare.feat_width
plain_mat = np.empty((row, col))
row = 0
for i,num in enumerate(prepare.sample_nums):
m = np.load(prepare.data_path+'feature_data.' + str(i) + '.npy')
plain_mat[row:row+num] = m
row += num
def mean_normalize(f_mat):
'''
get plain text f_range & f_mean
'''
ma = np.amax(f_mat, axis=0)
mi = np.amin(f_mat, axis=0)
return ma - mi, np.mean(f_mat, axis=0)
plain_range, plain_mean = mean_normalize(plain_mat)
print("max error in featrue range:", np.max(np.abs(res[0] - plain_range)))
print("max error in featrue mean:", np.max(np.abs(res[1] - plain_mean)))
......@@ -18,6 +18,7 @@ mpc math op layers.
from ..framework import MpcVariable
from ..framework import check_mpc_variable_and_dtype
from ..mpc_layer_helper import MpcLayerHelper
from .ml import reshape
__all__ = [
'mean',
......@@ -125,7 +126,7 @@ def square_error_cost(input, label):
square_out = helper.create_mpc_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='mpc_square',
type='mpc_square',
inputs={'X': [minus_out]},
outputs={'Out': [square_out]})
return square_out
......@@ -158,14 +159,14 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
Returns:
out(MpcVariable): (Tensor) The output of mean op
Examples:
Examples:
.. code-block:: python
import paddle_fl.mpc as pfl_mpc
pfl_mpc.init("aby3", int(args.role), "localhost", args.server, int(args.port))
data_1 = pfl_mpc.data(name='x', shape=[3, 3], dtype='int64')
pfl_mpc.layers.reshape(data_1, [1, 2]) # shape: [2, 1, 1]
pfl_mpc.layers.reshape(data_1, [1, 2]) # shape: [2, 1, 1]
# data_1 = np.full(shape=(3, 4), fill_value=2)
# reduce_sum: 24
"""
......@@ -178,7 +179,7 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
"'dim' should not contain 0, because dim[0] is share number."
)
else:
dim = [i for i in range(len(input.shape))][1:]
dim = [i for i in range(len(input.shape))][1:]
attrs = {
'dim': dim,
......@@ -194,6 +195,8 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
inputs={'X': input},
outputs={'Out': out},
attrs=attrs)
if out.shape == (2,):
out = reshape(out, list(out.shape) + [1])
return out
......@@ -37,6 +37,7 @@ __all__ = [
'pool2d',
'batch_norm',
'reshape',
'mean_normalize',
]
......@@ -612,7 +613,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
helper = MpcLayerHelper("reshape2", **locals())
_helper = LayerHelper("reshape2", **locals())
def get_new_shape_tensor(list_shape):
new_shape_tensor = []
for dim in list_shape:
......@@ -625,7 +626,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out)
new_shape_tensor.append(temp_out)
return new_shape_tensor
def get_attr_shape(list_shape):
unk_dim_idx = -1
attrs_shape = []
......@@ -662,13 +663,13 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
assert len(shape) > 0, ("The size of 'shape' in reshape can't be zero, "
"but received %s." % len(shape))
attrs["shape"] = get_attr_shape(shape)
if utils._contain_var(shape):
inputs['ShapeTensor'] = get_new_shape_tensor(shape)
elif isinstance(actual_shape, Variable):
actual_shape.stop_gradient = True
inputs["Shape"] = actual_shape
out = x if inplace else helper.create_mpc_variable_for_type_inference(
dtype=x.dtype)
x_shape = helper.create_mpc_variable_for_type_inference(dtype=x.dtype)
......@@ -680,3 +681,92 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
"XShape": x_shape})
return helper.append_activation(out)
def mean_normalize(f_min, f_max, f_mean, sample_num):
'''
Mean normalization is a method used to normalize the range of independent
variables or features of data.
Refer to:
https://en.wikipedia.org/wiki/Feature_scaling#Mean_normalization
Args:
f_min (Variable): A 2-D tensor with shape [P, N], where P is the party
num and N is the feature num. Each row contains the
local min feature val of N features.
f_max (Variable): A 2-D tensor with shape [P, N], where P is the party
num and N is the feature num. Each row contains the
local max feature val of N features.
f_mean (Variable): A 2-D tensor with shape [P, N], where P is the party
num and N is the feature num. Each row contains the
local min feature val of N features.
sample_num (Variable): A 1-D tensor with shape [P], where P is the
party num. Each element contains sample num
of party_i.
Returns:
f_range (Variable): A 1-D tensor with shape [N], where N is the
feature num. Each element contains global
range of feature_i.
f_mean_out (Variable): A 1-D tensor with shape [N], where N is the
feature num. Each element contains global
range of feature_i.
Examples:
.. code-block:: python
import paddle_fl.mpc as pfl_mpc
pfl_mpc.init("aby3", role, "localhost", redis_server, redis_port)
# 2 for share, 4 for 4 party, 100 for feat_num
input_size = [2, 4, 100]
mi = pfl_mpc.data(name='mi', shape=input_size, dtype='int64')
ma = pfl_mpc.data(name='ma', shape=input_size, dtype='int64')
me = pfl_mpc.data(name='me', shape=input_size, dtype='int64')
sn = pfl_mpc.data(name='sn', shape=input_size[:-1], dtype='int64')
out0, out1 = pfl_mpc.layers.mean_normalize(f_min=mi, f_max=ma,
f_mean=me, sample_num=sn)
exe = fluid.Executor(place=fluid.CPUPlace())
# feed encrypted data
f_range, f_mean = exe.run(feed={'mi': f_min, 'ma': f_max,
'me': f_mean, 'sn': sample_num}, fetch_list=[out0, out1])
'''
helper = MpcLayerHelper("mean_normalize", **locals())
# dtype = helper.input_dtype()
dtype = 'int64'
check_dtype(dtype, 'f_min', ['int64'], 'mean_normalize')
check_dtype(dtype, 'f_max', ['int64'], 'mean_normalize')
check_dtype(dtype, 'f_mean', ['int64'], 'mean_normalize')
check_dtype(dtype, 'sample_num', ['int64'], 'mean_normalize')
f_range = helper.create_mpc_variable_for_type_inference(dtype=f_min.dtype)
f_mean_out= helper.create_mpc_variable_for_type_inference(dtype=f_min.dtype)
# to avoid circular dependencies
from .math import reduce_sum
total_num = reduce_sum(sample_num)
op_type = 'mean_normalize'
helper.append_op(
type='mpc_' + op_type,
inputs={
"Min": f_min,
"Max": f_max,
"Mean": f_mean,
"SampleNum": sample_num,
"TotalNum": total_num,
},
outputs={
"Range": f_range,
"MeanOut": f_mean_out,
},
)
return f_range, f_mean_out
......@@ -140,13 +140,13 @@ class OpTest(unittest.TestCase):
target = kwargs['target']
partys = []
parties = []
for role in range(self.party_num):
kwargs.update({'role': role})
partys.append(Aby3Process(target=target, kwargs=kwargs))
partys[-1].start()
for party in partys:
parties.append(Aby3Process(target=target, kwargs=kwargs))
parties[-1].start()
for party in parties:
party.join()
if party.exception:
return party.exception
......
......@@ -26,6 +26,7 @@ TEST_MODULES=("test_datautils_aby3"
"test_op_conv"
"test_op_pool"
"test_op_metric"
"test_data_preprocessing"
"test_op_reshape"
"test_op_reduce_sum"
)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module test data preprocessing.
"""
import unittest
from multiprocessing import Manager
import numpy as np
import paddle.fluid as fluid
import paddle_fl.mpc as pfl_mpc
import mpc_data_utils as mdu
import paddle_fl.mpc.data_utils.aby3 as aby3
import test_op_base
def mean_norm_naive(f_mat):
ma = np.amax(f_mat, axis=0)
mi = np.amin(f_mat, axis=0)
return ma - mi, np.mean(f_mat, axis=0)
def gen_data(f_num, sample_nums):
f_mat = np.random.rand(np.sum(sample_nums), f_num)
f_min, f_max, f_mean = [], [], []
prev_idx = 0
for n in sample_nums:
i = prev_idx
j = i + n
ma = np.amax(f_mat[i:j], axis=0)
mi = np.amin(f_mat[i:j], axis=0)
me = np.mean(f_mat[i:j], axis=0)
f_min.append(mi)
f_max.append(ma)
f_mean.append(me)
prev_idx += n
f_min = np.array(f_min).reshape(sample_nums.size, f_num)
f_max = np.array(f_max).reshape(sample_nums.size, f_num)
f_mean = np.array(f_mean).reshape(sample_nums.size, f_num)
return f_mat, f_min, f_max, f_mean
class TestOpMeanNormalize(test_op_base.TestOpBase):
def mean_normalize(self, **kwargs):
"""
mean_normalize op ut
:param kwargs:
:return:
"""
role = kwargs['role']
pfl_mpc.init("aby3", role, "localhost", self.server, int(self.port))
mi = pfl_mpc.data(name='mi', shape=self.input_size, dtype='int64')
ma = pfl_mpc.data(name='ma', shape=self.input_size, dtype='int64')
me = pfl_mpc.data(name='me', shape=self.input_size, dtype='int64')
sn = pfl_mpc.data(name='sn', shape=self.input_size[:-1], dtype='int64')
out0, out1 = pfl_mpc.layers.mean_normalize(f_min=mi,
f_max=ma, f_mean=me, sample_num=sn)
exe = fluid.Executor(place=fluid.CPUPlace())
f_range, f_mean = exe.run(feed={'mi': kwargs['min'],
'ma': kwargs['max'], 'me': kwargs['mean'], 'sn': kwargs['sample_num']},fetch_list=[out0, out1])
self.f_range_list.append(f_range)
self.f_mean_list.append(f_mean)
def test_mean_normalize(self):
f_nums = 100
sample_nums = np.array(range(2, 10, 2))
mat, mi, ma, me = gen_data(f_nums, sample_nums)
self.input_size = [len(sample_nums), f_nums]
share = lambda x: np.array([x * mdu.mpc_one_share] * 2).astype('int64').reshape(
[2] + list(x.shape))
self.f_range_list = Manager().list()
self.f_mean_list = Manager().list()
ret = self.multi_party_run(target=self.mean_normalize,
min=share(mi), max=share(ma), mean=share(me), sample_num=share(sample_nums))
self.assertEqual(ret[0], True)
f_r = aby3.reconstruct(np.array(self.f_range_list))
f_m = aby3.reconstruct(np.array(self.f_mean_list))
plain_r, plain_m = mean_norm_naive(mat)
self.assertTrue(np.allclose(f_r, plain_r, atol=1e-4))
self.assertTrue(np.allclose(f_m, plain_m, atol=1e-4))
if __name__ == '__main__':
unittest.main()
......@@ -77,13 +77,13 @@ class TestOpBase(unittest.TestCase):
"""
target = kwargs['target']
parties = []
for role in range(self.party_num):
kwargs.update({'role': role})
party = Aby3Process(target=target, kwargs=kwargs)
party.start()
if role == self.party_num - 1:
party.join()
if party.exception:
return party.exception
else:
return (True,)
parties.append(Aby3Process(target=target, kwargs=kwargs))
parties[-1].start()
for party in parties:
party.join()
if party.exception:
return party.exception
return (True,)
......@@ -19,10 +19,10 @@ import unittest
from multiprocessing import Manager
import numpy as np
import test_op_base
from op_test import OpTest
import paddle_fl.mpc.data_utils.aby3 as aby3
import mpc_data_utils as mdu
import paddle.fluid as fluid
import paddle.fluid.core as core
......@@ -190,7 +190,7 @@ class TestConv2dOp(OpTest):
'dilation': self.dilations
}
share = lambda x: np.array([x * 65536/3] * 2).astype('int64')
share = lambda x: np.array([x * mdu.mpc_one_share] * 2).astype('int64')
input = np.random.random(self.input_size)
filter = np.random.uniform(-1, 1, self.filter_size)
......@@ -385,7 +385,7 @@ class TestConv2dOp_v2(OpTest):
'dilation': self.dilations
}
share = lambda x: np.array([x * 65536/3] * 2).astype('int64')
share = lambda x: np.array([x * mdu.mpc_one_share] * 2).astype('int64')
input = np.random.random(self.input_size)
filter = np.random.uniform(-1, 1, self.filter_size)
......
......@@ -20,6 +20,7 @@ import unittest
import numpy as np
import paddle.fluid as fluid
import paddle_fl.mpc as pfl_mpc
import mpc_data_utils as mdu
import test_op_base
......@@ -92,7 +93,7 @@ class TestOpPrecisionRecall(test_op_base.TestOpBase):
self.threshold = np.random.random()
preds, labels = [], []
self.exp_res = (0, [0] * 3)
share = lambda x: np.array([x * 65536/3] * 2).astype('int64').reshape(
share = lambda x: np.array([x * mdu.mpc_one_share] * 2).astype('int64').reshape(
[2] + self.input_size)
for _ in range(n):
......
......@@ -54,15 +54,15 @@ class TestOpPool2d(test_op_base.TestOpBase):
def test_pool2d(self):
data_1 = np.array(
[[[[1, 2, 3, 4, 0, 100],
[5, 6, 7, 8, 0, 100],
[[[[1, 2, 3, 4, 0, 100],
[5, 6, 7, 8, 0, 100],
[9, 10, 11, 12, 0, 200],
[13, 14, 15, 16, 0, 200]]]]).astype('float32')
expected_out = np.array(
[[[[6, 8, 100],
[[[[6, 8, 100],
[14, 16, 200]]]]).astype('float32')
print("input data_1: {} \n".format(data_1))
# print("input data_1: {} \n".format(data_1))
data_1_shares = aby3.make_shares(data_1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册