Unverified · Commit 84adcdd0 · authored by Qinghe JING, committed by GitHub

Merge pull request #64 from kaih70/master

sigmoid op added, some bugs fixed
-cmake_minimum_required(VERSION 3.13)
+cmake_minimum_required(VERSION 3.15)
project(PaddleEncrypted)
...
@@ -172,7 +172,7 @@ if(WIN32 OR APPLE OR NOT WITH_GPU OR ON_INFER)
endif()
if(${CMAKE_VERSION} VERSION_GREATER "3.5.2")
set(SHALLOW_CLONE "GIT_SHALLOW TRUE") # adds --depth=1 arg to git clone of External_Projects
set(SHALLOW_CLONE GIT_SHALLOW TRUE) # adds --depth=1 arg to git clone of External_Projects
endif()
########################### include third_party according to flags ###############################
...
@@ -138,6 +138,16 @@ public:
op_->relu(out_);
}
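// sigmoid (new in this PR): same dispatch pattern as relu above. The
// framework Tensors are unwrapped into protocol-level tensors via
// from_tensor, then the call is delegated to the underlying MPC protocol.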
void sigmoid(const Tensor *op, Tensor *out) override {
auto op_tuple = from_tensor(op);
auto out_tuple = from_tensor(out);
auto op_ = std::get<0>(op_tuple).get();
auto out_ = std::get<0>(out_tuple).get();
op_->sigmoid(out_);
}
void softmax(const Tensor *op, Tensor *out) override {
auto op_tuple = from_tensor(op);
auto out_tuple = from_tensor(out);
...
@@ -42,6 +42,8 @@ public:
virtual void relu(const Tensor *op, Tensor *out) = 0;
virtual void sigmoid(const Tensor *op, Tensor *out) = 0;
virtual void softmax(const Tensor *op, Tensor *out) = 0;
virtual void gt(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0;
...
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_sigmoid_cross_entropy_with_logits_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
const int kIgnoreIndex = -100;
class MpcSigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null.");
auto x_dims = ctx->GetInputDim("X");
auto labels_dims = ctx->GetInputDim("Label");
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank, labels_dims.size(),
"Input(X) and Input(Label) shall have the same rank.");
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
framework::product(labels_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank),
framework::slice_ddim(labels_dims, 0, rank),
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension.");
}
ctx->ShareDim("X", "Out");
ctx->ShareLoD("X", "Out");
}
};
class MpcSigmoidCrossEntropyWithLogitsGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shoudl be not null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@GRAD) should be not null.");
auto x_dims = ctx->GetInputDim("X");
auto labels_dims = ctx->GetInputDim("Label");
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
int rank = x_dims.size();
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
framework::product(labels_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank),
framework::slice_ddim(labels_dims, 0, rank),
"Input(X) and Input(Label) shall have the same shape.");
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank),
framework::slice_ddim(dout_dims, 0, rank),
"Input(X) and Input(Out@Grad) shall have the same shape.");
}
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
};
class MpcSigmoidCrossEntropyWithLogitsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D, "
"where N is the batch size and D is the number of classes. "
"This input is a tensor of logits computed by the previous "
" operator. Logits are unscaled log probabilities given as "
"log(p/(1-p)).");
AddInput("Label",
"(Tensor, default Tensor<float>), a 2-D tensor of the same type "
"and shape as X. This input is a tensor of probabalistic labels "
"for each logit");
AddOutput("Out",
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D "
" of elementwise logistic losses.");
AddComment(R"DOC(
MpcSigmoidCrossEntropyWithLogits Operator.
)DOC");
}
};
template <typename T>
class MpcSigmoidCrossEntropyWithLogitsGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> retv(new T());
retv->SetType("mpc_sigmoid_cross_entropy_with_logits_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput("Label", this->Input("Label"));
retv->SetInput("Out", this->Output("Out"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
return retv;
}
};
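// Inplace inferers: let the framework reuse the input buffer for the output
// when safe (X may share memory with Out; Out@GRAD with X@GRAD in the grad op).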
DECLARE_INPLACE_OP_INFERER(MpcSigmoidCrossEntropyWithLogitsInplaceInferer,
{"X", "Out"});
DECLARE_INPLACE_OP_INFERER(MpcSigmoidCrossEntropyWithLogitsGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
mpc_sigmoid_cross_entropy_with_logits, ops::MpcSigmoidCrossEntropyWithLogitsOp,
ops::MpcSigmoidCrossEntropyWithLogitsOpMaker,
ops::MpcSigmoidCrossEntropyWithLogitsGradOpMaker<paddle::framework::OpDesc>,
ops::MpcSigmoidCrossEntropyWithLogitsInplaceInferer);
REGISTER_OPERATOR(mpc_sigmoid_cross_entropy_with_logits_grad,
ops::MpcSigmoidCrossEntropyWithLogitsGradOp,
ops::MpcSigmoidCrossEntropyWithLogitsGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
mpc_sigmoid_cross_entropy_with_logits,
ops::MpcSigmoidCrossEntropyWithLogitsKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
mpc_sigmoid_cross_entropy_with_logits_grad,
ops::MpcSigmoidCrossEntropyWithLogitsGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "mpc_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Out = sigmoid(x) : prediction of x.
// todo: Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
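// Note: the todo formula above is the numerically stable form of the full loss
//   -Labels * log(sigmoid(X)) - (1 - Labels) * log(1 - sigmoid(X)),
// rewritten with max/abs so that exp() never overflows for large |X|.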
template <typename DeviceContext, typename T>
class MpcSigmoidCrossEntropyWithLogitsKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *in_x_t = ctx.Input<Tensor>("X");
auto *out_t = ctx.Output<Tensor>("Out");
out_t->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sigmoid(in_x_t, out_t);
}
};
// dX = sigmoid(X) - labels
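// Derivation: with L = -label * log(s) - (1 - label) * log(1 - s) and
// s = sigmoid(x), dL/dx = s - label. The forward pass already stored
// sigmoid(X) in Out, so backward is a single elementwise share subtraction.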
template <typename DeviceContext, typename T>
class MpcSigmoidCrossEntropyWithLogitsGradKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *in_label_t = ctx.Input<Tensor>("Label");
auto *in_sigmoid_t = ctx.Input<Tensor>("Out");
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(in_sigmoid_t, in_label_t, dx);
}
};
} // namespace operators
} // namespace paddle
# coding=utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MNIST Demo
"""
import sys
sys.path.append('../')
import env_set
import numpy as np
import time
import paddle
import paddle.fluid as fluid
import paddle_fl.mpc as pfl_mpc
import paddle_fl.mpc.data_utils.aby3 as aby3
import prepare_data
role, server, port = env_set.TestOptions().values()
# Initialize the ABY3 protocol; replace "localhost" with the party's real host when running distributed.
pfl_mpc.init("aby3", int(role), "localhost", server, int(port))
role = int(role)
# data preprocessing
BATCH_SIZE = 128
epoch_num = 2
# network
x = pfl_mpc.data(name='x', shape=[BATCH_SIZE, 784], dtype='int64')
y = pfl_mpc.data(name='y', shape=[BATCH_SIZE, 1], dtype='int64')
y_pre = pfl_mpc.layers.fc(input=x, size=1)
cost = pfl_mpc.layers.sigmoid_cross_entropy_with_logits(y_pre, y)
infer_program = fluid.default_main_program().clone(for_test=False)
avg_loss = pfl_mpc.layers.mean(cost)
optimizer = pfl_mpc.optimizer.SGD(learning_rate=0.001)
optimizer.minimize(avg_loss)
# train_reader
feature_reader = aby3.load_aby3_shares("/tmp/mnist2_feature", id=role, shape=(784,))
label_reader = aby3.load_aby3_shares("/tmp/mnist2_label", id=role, shape=(1,))
batch_feature = aby3.batch(feature_reader, BATCH_SIZE, drop_last=True)
batch_label = aby3.batch(label_reader, BATCH_SIZE, drop_last=True)
# test_reader
test_feature_reader = aby3.load_aby3_shares("/tmp/mnist2_test_feature", id=role, shape=(784,))
test_label_reader = aby3.load_aby3_shares("/tmp/mnist2_test_label", id=role, shape=(1,))
test_batch_feature = aby3.batch(test_feature_reader, BATCH_SIZE, drop_last=True)
test_batch_label = aby3.batch(test_label_reader, BATCH_SIZE, drop_last=True)
place = fluid.CPUPlace()
# async data loader
loader = fluid.io.DataLoader.from_generator(feed_list=[x, y], capacity=BATCH_SIZE)
batch_sample = paddle.reader.compose(batch_feature, batch_label)
loader.set_batch_generator(batch_sample, places=place)
test_loader = fluid.io.DataLoader.from_generator(feed_list=[x, y], capacity=BATCH_SIZE)
test_batch_sample = paddle.reader.compose(test_batch_feature, test_batch_label)
test_loader.set_batch_generator(test_batch_sample, places=place)
# loss file
loss_file = "/tmp/mnist_output_loss.part{}".format(role)
# train
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
start_time = time.time()
step = 0
for epoch_id in range(epoch_num):
# feed data via loader
for sample in loader():
exe.run(feed=sample)
if step % 50 == 0:
print('Epoch={}, Step={}'.format(epoch_id, step))
step += 1
end_time = time.time()
print('MPC training: epochs={}, batch_size={}, elapsed seconds: {}'
      .format(epoch_num, BATCH_SIZE, (end_time - start_time)))
# prediction
prediction_file = "/tmp/mnist_output_prediction.part{}".format(role)
for sample in test_loader():
prediction = exe.run(program=infer_program, feed=sample, fetch_list=[cost])
with open(prediction_file, 'ab') as f:
    # tobytes() is the non-deprecated spelling of tostring()
    f.write(np.array(prediction).tobytes())
# decrypt
#if 0 == role:
# prepare_data.decrypt_data_to_file("/tmp/mnist_output_prediction", (BATCH_SIZE,), "mpc_label")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preparation of MNIST data for MPC usage
"""
import sys
import numpy as np
import paddle
import six
from paddle_fl.mpc.data_utils import aby3
sample_reader = paddle.dataset.mnist.train()
test_reader = paddle.dataset.mnist.test()
def generate_encrypted_data():
"""
generate encrypted samples
"""
def encrypted_mnist_features():
"""
feature reader
"""
for instance in sample_reader():
yield aby3.make_shares(instance[0])
def encrypted_mnist_labels():
"""
label reader
"""
for instance in sample_reader():
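# Collapse the ten MNIST classes into a binary task: label 1 when the
# digit is 0, else label 0, matching the single-logit network in the demo.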
yield aby3.make_shares(np.array(1) if instance[1] == 0 else np.array(0))
aby3.save_aby3_shares(encrypted_mnist_features, "/tmp/mnist2_feature")
aby3.save_aby3_shares(encrypted_mnist_labels, "/tmp/mnist2_label")
def generate_encrypted_test_data():
"""
generate encrypted samples
"""
def encrypted_mnist_features():
"""
feature reader
"""
for instance in test_reader():
yield aby3.make_shares(instance[0])
def encrypted_mnist_labels():
"""
label reader
"""
for instance in test_reader():
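# Same binary relabeling as the training reader above.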
yield aby3.make_shares(np.array(1) if instance[1] == 0 else np.array(0))
aby3.save_aby3_shares(encrypted_mnist_features, "/tmp/mnist2_test_feature")
aby3.save_aby3_shares(encrypted_mnist_labels, "/tmp/mnist2_test_label")
def load_decrypt_data(filepath, shape):
"""
load the encrypted data and reconstruct
"""
part_readers = []
for id in six.moves.range(3):
part_readers.append(aby3.load_aby3_shares(filepath, id=id, shape=shape))
aby3_share_reader = paddle.reader.compose(part_readers[0], part_readers[1], part_readers[2])
for instance in aby3_share_reader():
p = aby3.reconstruct(np.array(instance))
print(p)
def decrypt_data_to_file(filepath, shape, decrypted_filepath):
"""
load the encrypted data and reconstruct
"""
part_readers = []
for id in six.moves.range(3):
part_readers.append(aby3.load_aby3_shares(filepath, id=id, shape=shape))
aby3_share_reader = paddle.reader.compose(part_readers[0], part_readers[1], part_readers[2])
for instance in aby3_share_reader():
p = aby3.reconstruct(np.array(instance))
with open(decrypted_filepath, 'a+') as f:
for i in p:
f.write(str(i) + '\n')
# generate_encrypted_data()
# generate_encrypted_test_data()
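A minimal driver sketch (an assumption, mirroring the commented-out calls above): run it once so the share files under /tmp exist before launching the demo.
if __name__ == '__main__':
    # One-time preparation: writes the ABY3 share files under /tmp
    # that the demo's load_aby3_shares readers expect.
    generate_encrypted_data()
    generate_encrypted_test_data()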
@@ -100,4 +100,4 @@ import prepare_data
print("uci_loss:")
prepare_data.load_decrypt_data("/tmp/uci_loss", (1, ))
print("prediction:")
prepare_data.load_decrypt_data("/tmp/uci_prediction", (1, ))
prepare_data.load_decrypt_data("/tmp/uci_prediction", (BATCH_SIZE, ))
@@ -22,7 +22,13 @@ import numpy
from ..framework import MpcVariable
from ..mpc_layer_helper import MpcLayerHelper
-__all__ = ['fc', 'relu', 'softmax']
+__all__ = [
+    'fc',
+    'relu',
+    'softmax',
+    'sigmoid_cross_entropy_with_logits',
+]
# add softmax, relu
@@ -124,10 +130,10 @@ def fc(input,
num_flatten_dims = len(input_shape) - 1
param_num_flatten_dims = num_flatten_dims
else:
-param_num_flatten_dims = num_flatten_dims + 1 # The first dimension '2' of input is share number.
+param_num_flatten_dims = num_flatten_dims + 1 # The first dimension '2' of input is share number.
param_shape = [
-reduce(lambda a, b: a * b, input_shape[param_num_flatten_dims:], 1)
-] + [size]
+reduce(lambda a, b: a * b, input_shape[param_num_flatten_dims:], 1)
+] + [size]
w = helper.create_mpc_parameter(
attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False)
tmp = helper.create_mpc_variable_for_type_inference(dtype)
@@ -150,8 +156,7 @@ def fc(input,
outputs={"Out": pre_bias},
attrs={"use_mkldnn": False})
# add bias
-pre_activation = helper.append_mpc_bias_op(
-    pre_bias, dim_start=num_flatten_dims)
+pre_activation = helper.append_mpc_bias_op(pre_bias, dim_start=num_flatten_dims)
# add activation
return helper.append_mpc_activation(pre_activation)
@@ -220,5 +225,39 @@ def relu(input, name=None):
helper = MpcLayerHelper('relu', **locals())
dtype = helper.input_dtype(input_param_name='input')
out = helper.create_mpc_variable_for_type_inference(dtype)
helper.append_op(type="mpc_relu", inputs={"X": input}, outputs={"Y": out})
helper.append_op(
type="mpc_relu", inputs={"X": input}, outputs={"Y": out})
return out
def sigmoid_cross_entropy_with_logits(x,
                                      label,
                                      name=None):
"""
sigmoid_cross_entropy_with_logits
forward: out = sigmoid(x). todo: add cross_entropy
backward: dx = sigmoid(x) - label
Args:
x(MpcVariable): input
label(MpcVariable): labels
name(str|None): The default value is None. Normally there is
    no need for the user to set this property. For more information,
    please refer to :ref:`api_guide_Name`.
Returns:
out(MpcVariable): out = sigmoid(x)
"""
helper = MpcLayerHelper("sigmoid_cross_entropy_with_logits", **locals())
if name is None:
out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype)
else:
out = helper.create_mpc_variable(
name=name, dtype=x.dtype, persistable=False)
helper.append_op(
type="mpc_sigmoid_cross_entropy_with_logits",
inputs={"X": x,
"Label": label},
outputs={"Out": out})
return out
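A minimal usage sketch for the new layer (an assumption, not part of the diff: the init call, shapes, and batch size follow the MNIST demo above).
import paddle_fl.mpc as pfl_mpc

pfl_mpc.init("aby3", role, "localhost", server, port)  # role/server/port as in the demo
x = pfl_mpc.data(name='x', shape=[128, 784], dtype='int64')
y = pfl_mpc.data(name='y', shape=[128, 1], dtype='int64')
logits = pfl_mpc.layers.fc(input=x, size=1)
# Per-element loss; forward computes sigmoid(logits), backward uses sigmoid(logits) - y.
loss = pfl_mpc.layers.sigmoid_cross_entropy_with_logits(logits, y)
avg_loss = pfl_mpc.layers.mean(loss)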