diff --git a/CMakeLists.txt b/CMakeLists.txt
index f03ba254085765ef0a67cf28d856d1f87aec29cf..1a0e597b5c37af3cceb8693986cad7f6fac3ba8e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,4 +1,4 @@
-cmake_minimum_required(VERSION 3.13)
+cmake_minimum_required(VERSION 3.15)
 
 project(PaddleEncrypted)
 
diff --git a/README.md b/README.md
index 38a72106477323daeed5356f293c82cac12c500f..2a1fc63eb85bc4aac97b1178a0448700f4554580 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,7 @@ In PaddleFL, horizontal and vertical federated learning strategies will be imple
 
 - **Vertical Federated Learning**: Logistic Regression with PrivC, Neural Network with third-party PrivC [5]
 
-- **Horizontal Federated Learning**: Federated Averaging [2], Differential Privacy [6]
+- **Horizontal Federated Learning**: Federated Averaging [2], Differential Privacy [6], Secure Aggregation
 
 #### Training Strategy
 
@@ -38,7 +38,7 @@ In PaddleFL, components for defining a federated learning task and training a fe
 
 #### Compile Time
 
-- **FL-Strategy**: a user can define federated learning strategies with FL-Strategy such as Fed-Avg[1]
+- **FL-Strategy**: a user can define federated learning strategies with FL-Strategy such as Fed-Avg[2]
 
 - **User-Defined-Program**: PaddlePaddle's program that defines the machine learning model structure and training strategies such as multi-task learning.
 
diff --git a/README_cn.md b/README_cn.md
index aa0bf2c55dca0d4b909a542800bbc8d2f9279a63..2ab1330c106cc9634bbbf0d0d9902e691485463f 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -16,7 +16,7 @@ PaddleFL是一个基于PaddlePaddle的开源联邦学习框架。研究人员可
 
 - **纵向联邦学习**: 带privc的逻辑回归,带第三方privc的神经网络[5]
 
-- **横向联邦学习**: 联邦平均 [2],差分隐私 [6]
+- **横向联邦学习**: 联邦平均 [2],差分隐私 [6],安全聚合
 
 #### 训练策略
 
@@ -35,7 +35,7 @@ PaddleFL是一个基于PaddlePaddle的开源联邦学习框架。研究人员可
 
 #### 编译时
 
-- **FL-Strategy**: 用户可以使用FL-Strategy定义联邦学习策略,例如Fed-Avg[1]。
+- **FL-Strategy**: 用户可以使用FL-Strategy定义联邦学习策略,例如Fed-Avg[2]。
 
 - **User-Defined-Program**: PaddlePaddle的程序定义了机器学习模型结构和训练策略,如多任务学习。
 
diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake
index 3104a991d4fa9f9495fde309076728c9c2987a62..b2ccdf431e0e4a990e9f0f1d524539c3d54385b2 100644
--- a/cmake/third_party.cmake
+++ b/cmake/third_party.cmake
@@ -172,7 +172,7 @@ if(WIN32 OR APPLE OR NOT WITH_GPU OR ON_INFER)
 endif()
 
 if(${CMAKE_VERSION} VERSION_GREATER "3.5.2")
-    set(SHALLOW_CLONE "GIT_SHALLOW TRUE") # adds --depth=1 arg to git clone of External_Projects
+    set(SHALLOW_CLONE GIT_SHALLOW TRUE) # adds --depth=1 arg to git clone of External_Projects
 endif()
 
 ########################### include third_party according to flags ###############################
diff --git a/core/paddlefl_mpc/mpc_protocol/aby3_operators.h b/core/paddlefl_mpc/mpc_protocol/aby3_operators.h
index db895741d0b5ac22fe239f20284f778e5aa468ea..9981b2a0087d6e7c2915f88d6cb9ad32b9e234a0 100644
--- a/core/paddlefl_mpc/mpc_protocol/aby3_operators.h
+++ b/core/paddlefl_mpc/mpc_protocol/aby3_operators.h
@@ -138,6 +138,16 @@ public:
     op_->relu(out_);
   }
 
+  void sigmoid(const Tensor *op, Tensor *out) override {
+    auto op_tuple = from_tensor(op);
+    auto out_tuple = from_tensor(out);
+
+    auto op_ = std::get<0>(op_tuple).get();
+    auto out_ = std::get<0>(out_tuple).get();
+
+    op_->sigmoid(out_);
+  }
+
   void softmax(const Tensor *op, Tensor *out) override {
     auto op_tuple = from_tensor(op);
     auto out_tuple = from_tensor(out);
diff --git a/core/paddlefl_mpc/mpc_protocol/mpc_operators.h b/core/paddlefl_mpc/mpc_protocol/mpc_operators.h
index c0683f17538e516ee09e163fd37bd6dcd95ee810..8fc6977512ea38b3b21d7b301fc91d1595c30c43 100644
--- a/core/paddlefl_mpc/mpc_protocol/mpc_operators.h
+++ b/core/paddlefl_mpc/mpc_protocol/mpc_operators.h
@@ -42,6 +42,8 @@ public:
 
   virtual void relu(const Tensor *op, Tensor *out) = 0;
 
+  virtual void sigmoid(const Tensor *op, Tensor *out) = 0;
+
   virtual void softmax(const Tensor *op, Tensor *out) = 0;
 
   virtual void gt(const Tensor *lhs, const Tensor *rhs, Tensor *out) = 0;
diff --git a/core/paddlefl_mpc/operators/mpc_sigmoid_cross_entropy_with_logits_op.cc b/core/paddlefl_mpc/operators/mpc_sigmoid_cross_entropy_with_logits_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..09035196947d2cfa20ae691102ebf60426f421e9
--- /dev/null
+++ b/core/paddlefl_mpc/operators/mpc_sigmoid_cross_entropy_with_logits_op.cc
@@ -0,0 +1,161 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "mpc_sigmoid_cross_entropy_with_logits_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+const int kIgnoreIndex = -100;
+
+class MpcSigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
+public:
+    using framework::OperatorWithKernel::OperatorWithKernel;
+
+    void InferShape(framework::InferShapeContext* ctx) const override {
+        PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
+        PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+        PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should be not null.");
+
+        auto x_dims = ctx->GetInputDim("X");
+        auto labels_dims = ctx->GetInputDim("Label");
+
+        int rank = x_dims.size();
+        PADDLE_ENFORCE_EQ(rank, labels_dims.size(),
+                          "Input(X) and Input(Label) shall have the same rank.");
+        bool check = true;
+        if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
+                                    framework::product(labels_dims) <= 0)) {
+            check = false;
+        }
+
+        if (check) {
+            PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank),
+                              framework::slice_ddim(labels_dims, 0, rank),
+                              "Input(X) and Input(Label) shall have the same shape "
+                              "except the last dimension.");
+        }
+
+        ctx->ShareDim("X", "Out");
+        ctx->ShareLoD("X", "Out");
+    }
+};
+
+class MpcSigmoidCrossEntropyWithLogitsGradOp : public framework::OperatorWithKernel {
+public:
+    using framework::OperatorWithKernel::OperatorWithKernel;
+
+    void InferShape(framework::InferShapeContext* ctx) const override {
+        PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
+        PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
+        PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                       "Input(Out@GRAD) should be not null.");
+        PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                       "Output(X@GRAD) should be not null.");
+
+        auto x_dims = ctx->GetInputDim("X");
+        auto labels_dims = ctx->GetInputDim("Label");
+        auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+
+        int rank = x_dims.size();
+        bool check = true;
+        if ((!ctx->IsRuntime()) &&
+            (framework::product(x_dims) <= 0 ||
+             framework::product(labels_dims) <= 0)) {
+            check = false;
+        }
+
+        if (check) {
+            PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank),
+                              framework::slice_ddim(labels_dims, 0, rank),
+                              "Input(X) and Input(Label) shall have the same shape.");
+
+            PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank),
+                              framework::slice_ddim(dout_dims, 0, rank),
+                              "Input(X) and Input(Out@Grad) shall have the same shape.");
+        }
+
+        ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+    }
+};
+
+class MpcSigmoidCrossEntropyWithLogitsOpMaker : public framework::OpProtoAndCheckerMaker {
+public:
+    void Make() override {
+        AddInput("X",
+                 "(Tensor, default Tensor), a 2-D tensor with shape N x D, "
+                 "where N is the batch size and D is the number of classes. "
+                 "This input is a tensor of logits computed by the previous "
+                 "operator. Logits are unscaled log probabilities given as "
+                 "log(p/(1-p)).");
+        AddInput("Label",
+                 "(Tensor, default Tensor), a 2-D tensor of the same type "
+                 "and shape as X. This input is a tensor of probabilistic labels "
+                 "for each logit.");
+        AddOutput("Out",
+                  "(Tensor, default Tensor), a 2-D tensor with shape N x D "
+                  "of elementwise logistic losses.");
+        AddComment(R"DOC(
+MpcSigmoidCrossEntropyWithLogits Operator.
+)DOC");
+    }
+};
+
+template <typename T>
+class MpcSigmoidCrossEntropyWithLogitsGradOpMaker : public framework::SingleGradOpDescMaker {
+public:
+    using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+protected:
+    std::unique_ptr<T> Apply() const override {
+        std::unique_ptr<T> retv(new T());
+        retv->SetType("mpc_sigmoid_cross_entropy_with_logits_grad");
+        retv->SetInput("X", this->Input("X"));
+        retv->SetInput("Label", this->Input("Label"));
+        retv->SetInput("Out", this->Output("Out"));
+        retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+        retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+        retv->SetAttrMap(this->Attrs());
+        return retv;
+    }
+};
+
+DECLARE_INPLACE_OP_INFERER(MpcSigmoidCrossEntropyWithLogitsInplaceInferer,
+                           {"X", "Out"});
+DECLARE_INPLACE_OP_INFERER(MpcSigmoidCrossEntropyWithLogitsGradInplaceInferer,
+                           {framework::GradVarName("Out"),
+                            framework::GradVarName("X")});
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(
+    mpc_sigmoid_cross_entropy_with_logits, ops::MpcSigmoidCrossEntropyWithLogitsOp,
+    ops::MpcSigmoidCrossEntropyWithLogitsOpMaker,
+    ops::MpcSigmoidCrossEntropyWithLogitsGradOpMaker<paddle::framework::OpDesc>,
+    ops::MpcSigmoidCrossEntropyWithLogitsInplaceInferer);
+REGISTER_OPERATOR(mpc_sigmoid_cross_entropy_with_logits_grad,
+                  ops::MpcSigmoidCrossEntropyWithLogitsGradOp,
+                  ops::MpcSigmoidCrossEntropyWithLogitsGradInplaceInferer);
+REGISTER_OP_CPU_KERNEL(
+    mpc_sigmoid_cross_entropy_with_logits,
+    ops::MpcSigmoidCrossEntropyWithLogitsKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    mpc_sigmoid_cross_entropy_with_logits_grad,
+    ops::MpcSigmoidCrossEntropyWithLogitsGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
diff --git a/core/paddlefl_mpc/operators/mpc_sigmoid_cross_entropy_with_logits_op.h b/core/paddlefl_mpc/operators/mpc_sigmoid_cross_entropy_with_logits_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..8907f0bbcc7440b685d4f7191d56288753dfb398
--- /dev/null
+++ b/core/paddlefl_mpc/operators/mpc_sigmoid_cross_entropy_with_logits_op.h
@@ -0,0 +1,53 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "mpc_op.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+// Out = sigmoid(x) : prediction of x.
+// todo: Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
+template <typename DeviceContext, typename T>
+class MpcSigmoidCrossEntropyWithLogitsKernel : public MpcOpKernel<T> {
+public:
+    void ComputeImpl(const framework::ExecutionContext &ctx) const override {
+        auto *in_x_t = ctx.Input<Tensor>("X");
+        auto *out_t = ctx.Output<Tensor>("Out");
+        out_t->mutable_data<T>(ctx.GetPlace());
+
+        mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sigmoid(in_x_t, out_t);
+    }
+};
+
+// dX = sigmoid(X) - labels
+template <typename DeviceContext, typename T>
+class MpcSigmoidCrossEntropyWithLogitsGradKernel : public MpcOpKernel<T> {
+public:
+    void ComputeImpl(const framework::ExecutionContext &ctx) const override {
+        auto *in_label_t = ctx.Input<Tensor>("Label");
+        auto *in_sigmoid_t = ctx.Input<Tensor>("Out");
+        auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+        dx->mutable_data<T>(ctx.GetPlace());
+
+        mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(in_sigmoid_t, in_label_t, dx);
+    }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle_fl/mpc/examples/mnist_demo/mnist_demo.py b/python/paddle_fl/mpc/examples/mnist_demo/mnist_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1b6e5b984fa739c1c21471becab80b6314841b3
--- /dev/null
+++ b/python/paddle_fl/mpc/examples/mnist_demo/mnist_demo.py
@@ -0,0 +1,109 @@
+# coding=utf-8
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+MNIST Demo
+"""
+
+import sys
+sys.path.append('../')
+
+import env_set
+import numpy as np
+import time
+
+import paddle
+import paddle.fluid as fluid
+import paddle_fl.mpc as pfl_mpc
+import paddle_fl.mpc.data_utils.aby3 as aby3
+import prepare_data
+
+role, server, port = env_set.TestOptions().values()
+# modify the host ("localhost" by default) if the parties run on different machines
+pfl_mpc.init("aby3", int(role), "localhost", server, int(port))
+role = int(role)
+
+# data preprocessing
+BATCH_SIZE = 128
+epoch_num = 2
+
+# network
+x = pfl_mpc.data(name='x', shape=[BATCH_SIZE, 784], dtype='int64')
+y = pfl_mpc.data(name='y', shape=[BATCH_SIZE, 1], dtype='int64')
+
+y_pre = pfl_mpc.layers.fc(input=x, size=1)
+cost = pfl_mpc.layers.sigmoid_cross_entropy_with_logits(y_pre, y)
+
+infer_program = fluid.default_main_program().clone(for_test=False)
+
+avg_loss = pfl_mpc.layers.mean(cost)
+optimizer = pfl_mpc.optimizer.SGD(learning_rate=0.001)
+optimizer.minimize(avg_loss)
+
+# train_reader
+feature_reader = aby3.load_aby3_shares("/tmp/mnist2_feature", id=role, shape=(784,))
+label_reader = aby3.load_aby3_shares("/tmp/mnist2_label", id=role, shape=(1,))
+batch_feature = aby3.batch(feature_reader, BATCH_SIZE, drop_last=True)
+batch_label = aby3.batch(label_reader, BATCH_SIZE, drop_last=True)
+
+# test_reader
+test_feature_reader = aby3.load_aby3_shares("/tmp/mnist2_test_feature", id=role, shape=(784,))
+test_label_reader = aby3.load_aby3_shares("/tmp/mnist2_test_label", id=role, shape=(1,))
+test_batch_feature = aby3.batch(test_feature_reader, BATCH_SIZE, drop_last=True)
+test_batch_label = aby3.batch(test_label_reader, BATCH_SIZE, drop_last=True)
+
+place = fluid.CPUPlace()
+
+# async data loader
+loader = fluid.io.DataLoader.from_generator(feed_list=[x, y], capacity=BATCH_SIZE)
+batch_sample = paddle.reader.compose(batch_feature, batch_label)
+loader.set_batch_generator(batch_sample, places=place)
+
+test_loader = fluid.io.DataLoader.from_generator(feed_list=[x, y], capacity=BATCH_SIZE)
+test_batch_sample = paddle.reader.compose(test_batch_feature, test_batch_label)
+test_loader.set_batch_generator(test_batch_sample, places=place)
+
+# loss file
+loss_file = "/tmp/mnist_output_loss.part{}".format(role)
+
+# train
+exe = fluid.Executor(place)
+exe.run(fluid.default_startup_program())
+
+start_time = time.time()
+step = 0
+for epoch_id in range(epoch_num):
+    # feed data via loader
+    for sample in loader():
+        exe.run(feed=sample)
+        if step % 50 == 0:
+            print('Epoch={}, Step={}'.format(epoch_id, step))
+        step += 1
+
+end_time = time.time()
+print('Mpc Training of Epoch={} Batch_size={}, cost time in seconds:{}'
+      .format(epoch_num, BATCH_SIZE, (end_time - start_time)))
+
+# prediction
+prediction_file = "/tmp/mnist_output_prediction.part{}".format(role)
+for sample in test_loader():
+    prediction = exe.run(program=infer_program, feed=sample, fetch_list=[cost])
+    with open(prediction_file, 'ab') as f:
+        f.write(np.array(prediction).tostring())
+
+# decrypt
+#if 0 == role:
+#    prepare_data.decrypt_data_to_file("/tmp/mnist_output_prediction", (BATCH_SIZE,), "mpc_label")
+
+
diff --git a/python/paddle_fl/mpc/examples/mnist_demo/prepare_data.py b/python/paddle_fl/mpc/examples/mnist_demo/prepare_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..13e2a820b0da7fac9538b31c0227d03a8c49993c
--- /dev/null
+++ b/python/paddle_fl/mpc/examples/mnist_demo/prepare_data.py
@@ -0,0 +1,99 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Preparation of MNIST data for MPC usage
+"""
+import sys
+
+import numpy as np
+import paddle
+import six
+from paddle_fl.mpc.data_utils import aby3
+
+sample_reader = paddle.dataset.mnist.train()
+test_reader = paddle.dataset.mnist.test()
+
+def generate_encrypted_data():
+    """
+    generate encrypted training samples
+    """
+    def encrypted_mnist_features():
+        """
+        feature reader
+        """
+        for instance in sample_reader():
+            yield aby3.make_shares(instance[0])
+
+    def encrypted_mnist_labels():
+        """
+        label reader (binarized: 1 if the digit is 0, else 0)
+        """
+        for instance in sample_reader():
+            yield aby3.make_shares(np.array(1) if instance[1] == 0 else np.array(0))
+
+    aby3.save_aby3_shares(encrypted_mnist_features, "/tmp/mnist2_feature")
+    aby3.save_aby3_shares(encrypted_mnist_labels, "/tmp/mnist2_label")
+
+def generate_encrypted_test_data():
+    """
+    generate encrypted test samples
+    """
+    def encrypted_mnist_features():
+        """
+        feature reader
+        """
+        for instance in test_reader():
+            yield aby3.make_shares(instance[0])
+
+    def encrypted_mnist_labels():
+        """
+        label reader (binarized: 1 if the digit is 0, else 0)
+        """
+        for instance in test_reader():
+            yield aby3.make_shares(np.array(1) if instance[1] == 0 else np.array(0))
+
+    aby3.save_aby3_shares(encrypted_mnist_features, "/tmp/mnist2_test_feature")
+    aby3.save_aby3_shares(encrypted_mnist_labels, "/tmp/mnist2_test_label")
+
+def load_decrypt_data(filepath, shape):
+    """
+    load the encrypted data, reconstruct, and print
+    """
+    part_readers = []
+    for id in six.moves.range(3):
+        part_readers.append(aby3.load_aby3_shares(filepath, id=id, shape=shape))
+    aby3_share_reader = paddle.reader.compose(part_readers[0], part_readers[1], part_readers[2])
+
+    for instance in aby3_share_reader():
+        p = aby3.reconstruct(np.array(instance))
+        print(p)
+
+def decrypt_data_to_file(filepath, shape, decrypted_filepath):
+    """
+    load the encrypted data, reconstruct, and write to file
+    """
+    part_readers = []
+    for id in six.moves.range(3):
+        part_readers.append(aby3.load_aby3_shares(filepath, id=id, shape=shape))
+    aby3_share_reader = paddle.reader.compose(part_readers[0], part_readers[1], part_readers[2])
+
+    for instance in aby3_share_reader():
+        p = aby3.reconstruct(np.array(instance))
+        with open(decrypted_filepath, 'a+') as f:
+            for i in p:
+                f.write(str(i) + '\n')
+
+# generate_encrypted_data()
+# generate_encrypted_test_data()
diff --git a/python/paddle_fl/mpc/layers/ml.py b/python/paddle_fl/mpc/layers/ml.py
index 0078583f2e15f739004bc976d7ee2af2d90f60c4..39cf6b90ac2e5bf8a01ae93fab305eed3d435f6d 100644
--- a/python/paddle_fl/mpc/layers/ml.py
+++ b/python/paddle_fl/mpc/layers/ml.py
@@ -22,7 +22,13 @@
 import numpy
 from ..framework import MpcVariable
 from ..mpc_layer_helper import MpcLayerHelper
 
-__all__ = ['fc', 'relu', 'softmax']
+__all__ = [
+    'fc',
+    'relu',
+    'softmax',
+    'sigmoid_cross_entropy_with_logits',
+]
+
 
 # add softmax, relu
 
@@ -124,10 +130,10 @@
         num_flatten_dims = len(input_shape) - 1
         param_num_flatten_dims = num_flatten_dims
     else:
-        param_num_flatten_dims = num_flatten_dims + 1 # The first dimension '2' of input is share number.
+        param_num_flatten_dims = num_flatten_dims + 1  # The first dimension '2' of input is share number.
     param_shape = [
-            reduce(lambda a, b: a * b, input_shape[param_num_flatten_dims:], 1)
-        ] + [size]
+        reduce(lambda a, b: a * b, input_shape[param_num_flatten_dims:], 1)
+    ] + [size]
     w = helper.create_mpc_parameter(
         attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False)
     tmp = helper.create_mpc_variable_for_type_inference(dtype)
@@ -150,8 +156,7 @@ def fc(input,
         outputs={"Out": pre_bias},
         attrs={"use_mkldnn": False})
 
     # add bias
-    pre_activation = helper.append_mpc_bias_op(
-        pre_bias, dim_start=num_flatten_dims)
+    pre_activation = helper.append_mpc_bias_op(pre_bias, dim_start=num_flatten_dims)
     # add activation
     return helper.append_mpc_activation(pre_activation)
@@ -220,5 +225,39 @@ def relu(input, name=None):
     helper = MpcLayerHelper('relu', **locals())
     dtype = helper.input_dtype(input_param_name='input')
     out = helper.create_mpc_variable_for_type_inference(dtype)
-    helper.append_op(type="mpc_relu", inputs={"X": input}, outputs={"Y": out})
+    helper.append_op(
+        type="mpc_relu", inputs={"X": input}, outputs={"Y": out})
     return out
+
+
+def sigmoid_cross_entropy_with_logits(x,
+                                      label,
+                                      name=None):
+    """
+    sigmoid_cross_entropy_with_logits
+    forward: out = sigmoid(x). todo: add cross_entropy
+    backward: dx = sigmoid(x) - label
+    Args:
+        x(MpcVariable): input
+        label(MpcVariable): labels
+        name(str|None): The default value is None. Normally there is
+            no need for user to set this property. For more information,
+            please refer to :ref:`api_guide_Name`
+    Returns:
+        out(MpcVariable): out = sigmoid(x)
+    """
+
+    helper = MpcLayerHelper("sigmoid_cross_entropy_with_logits", **locals())
+
+    if name is None:
+        out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype)
+    else:
+        out = helper.create_mpc_variable(
+            name=name, dtype=x.dtype, persistable=False)
+
+    helper.append_op(
+        type="mpc_sigmoid_cross_entropy_with_logits",
+        inputs={"X": x,
+                "Label": label},
+        outputs={"Out": out})
+    return out
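
Note (not part of the patch): the new kernels implement forward out = sigmoid(x) and backward dx = sigmoid(x) - label. That backward expression is the analytic gradient of the numerically stable sigmoid cross entropy the forward TODO refers to, max(x, 0) - x * label + log(1 + exp(-abs(x))), which is why the grad kernel needs only the protocol's sub primitive plus the saved forward output Out (wired into the grad op's inputs by the op maker). The plaintext numpy sketch below checks that identity numerically; sigmoid and bce_with_logits are local illustrative helpers, not PaddleFL APIs, and nothing here touches shares or the ABY3 protocol.

# Plaintext sanity check for the math behind the MPC kernels (illustrative only).
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def bce_with_logits(x, label):
    # numerically stable sigmoid cross entropy, as in the kernel's TODO comment
    return np.maximum(x, 0) - x * label + np.log1p(np.exp(-np.abs(x)))

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
label = np.array([0.0, 1.0, 1.0, 0.0, 1.0])

# analytic gradient, i.e. what MpcSigmoidCrossEntropyWithLogitsGradKernel computes
grad = sigmoid(x) - label

# central-difference numeric gradient of the loss, for comparison
eps = 1e-6
num_grad = (bce_with_logits(x + eps, label) - bce_with_logits(x - eps, label)) / (2 * eps)

assert np.allclose(grad, num_grad, atol=1e-4)
print("dLoss/dx == sigmoid(x) - label:", grad)

This also explains the demo setup: prepare_data.py binarizes MNIST into "digit 0 vs. rest", so a single-logit fc layer plus this loss is an ordinary binary classifier, evaluated entirely on secret shares.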