diff --git a/CMakeLists.txt b/CMakeLists.txt
index aac6879b85e184a1b7ad7b034c78d92295c7e207..6df43c8ac458b0283cf819f905ddc4d0d1978e95 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -34,8 +34,8 @@ execute_process(COMMAND ${PYTHON} -c "import paddle;print(paddle.version.full_ve
 RESULT_VARIABLE ret OUTPUT_VARIABLE paddle_version OUTPUT_STRIP_TRAILING_WHITESPACE)
 if (NOT ret)
- if (NOT ${paddle_version} STREQUAL "1.6.3")
- message(FATAL_ERROR "Paddle installation of 1.6.3 is required but ${paddle_version} is found")
+ if (NOT ${paddle_version} STREQUAL "1.8.0")
+ message(FATAL_ERROR "Paddle installation of 1.8.0 is required but ${paddle_version} is found")
 endif()
 else()
 message(FATAL_ERROR "Could not get paddle version.")
diff --git a/cmake/external/gtest.cmake b/cmake/external/gtest.cmake
index 0e539b4487a4e3651c87aa5a9f8d826af9ff02a4..c116a9959055d0727eb2391a6c3067d30c3dd2bd 100644
--- a/cmake/external/gtest.cmake
+++ b/cmake/external/gtest.cmake
@@ -1,10 +1,10 @@
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/cmake/external/hiredis.cmake b/cmake/external/hiredis.cmake
index 0261ca56753fac4b89cf423242506759da5e1089..ee536a269729c4c79d79afb2dae35c6555b2923e 100644
--- a/cmake/external/hiredis.cmake
+++ b/cmake/external/hiredis.cmake
@@ -1,10 +1,10 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/cmake/external/openssl.cmake b/cmake/external/openssl.cmake
index 680be6a4d938dfa0355942181145e51f38a12da4..c17e3142c298d631bcbe66e0cf351c0363ea59f0 100644
--- a/cmake/external/openssl.cmake
+++ b/cmake/external/openssl.cmake
@@ -1,10 +1,10 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/cmake/external/pybind11.cmake b/cmake/external/pybind11.cmake
index a7a6a771abeb430da5d2770fa4bb830b23c3b1d6..7449085cb0156cb66a603e928e9104c59768f9f6 100644
--- a/cmake/external/pybind11.cmake
+++ b/cmake/external/pybind11.cmake
@@ -1,10 +1,10 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index 74e3dd8eaf80986c145c523ec3e3a96c44dc987d..839e560ddb2e89068b80270ba71f52ce79858dcb 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -1,10 +1,10 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake
index b2ccdf431e0e4a990e9f0f1d524539c3d54385b2..f33b771e8597c3477eb522bd5b7c163bb49c37f4 100644
--- a/cmake/third_party.cmake
+++ b/cmake/third_party.cmake
@@ -1,10 +1,10 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/core/paddlefl_mpc/operators/mpc_compare_op.cc b/core/paddlefl_mpc/operators/mpc_compare_op.cc
index 66460fab4aecc1b9075d0daa2500211fc098b6cd..8830fc0d8d086c6eb7f4b7fd305fb032f383ab8e 100644
--- a/core/paddlefl_mpc/operators/mpc_compare_op.cc
+++ b/core/paddlefl_mpc/operators/mpc_compare_op.cc
@@ -1,19 +1,19 @@
-// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ -#include "mpc_compare_op.h" #include "paddle/fluid/framework/op_registry.h" +#include "mpc_compare_op.h" namespace paddle { namespace operators { @@ -23,85 +23,73 @@ using Tensor = framework::Tensor; class MpcCompareOp : public framework::OperatorWithKernel { public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - platform::errors::NotFound( - "Input(X) of MpcCompareOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, - platform::errors::NotFound( - "Input(Y) of MpcCompareOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - platform::errors::NotFound( - "Output(Out) of MpcCompareOp should not be null.")); - - auto dim_x = ctx->GetInputDim("X"); - auto dim_y = ctx->GetInputDim("Y"); - PADDLE_ENFORCE_GE(dim_x.size(), dim_y.size(), - "The size of dim_y should not be greater than dim_x's."); - - ctx->ShareDim("Y", /*->*/ "Out"); - ctx->ShareLoD("Y", /*->*/ "Out"); + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("X"), true, + platform::errors::NotFound("Input(X) of MpcCompareOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Y"), true, + platform::errors::NotFound("Input(Y) of MpcCompareOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("Out"), true, + platform::errors::NotFound("Output(Out) of MpcCompareOp should not be null.")); + + auto dim_x = ctx->GetInputDim("X"); + auto dim_y = ctx->GetInputDim("Y"); + PADDLE_ENFORCE_GE(dim_x.size(), dim_y.size(), + "The size of dim_y should not be greater than dim_x's."); + + ctx->ShareDim("Y", /*->*/ "Out"); + ctx->ShareLoD("Y", /*->*/ "Out"); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } - framework::OpKernelType - GetExpectedKernelType(const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); - } }; class MpcCompareOpMaker : public framework::OpProtoAndCheckerMaker { public: - void Make() override { - AddInput("X", "(Tensor), The first input tensor of MpcCompareOp."); - AddInput("Y", "(Tensor), The second input tensor of MpcCompareOp."); - AddOutput("Out", "(Tensor), The output tensor of MpcCompareOp."); - AddComment(R"DOC( + void Make() override { + AddInput("X", "(Tensor), The first input tensor of MpcCompareOp."); + AddInput("Y", "(Tensor), The second input tensor of MpcCompareOp."); + AddOutput("Out", "(Tensor), The output tensor of MpcCompareOp."); + AddComment(R"DOC( MPC Compare Operator. 
)DOC"); - } + } }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(mpc_greater_than, ops::MpcCompareOp, - ops::MpcCompareOpMaker); -REGISTER_OP_CPU_KERNEL( - mpc_greater_than, - ops::MpcCompareOpKernel); - -REGISTER_OP_WITHOUT_GRADIENT(mpc_greater_equal, ops::MpcCompareOp, - ops::MpcCompareOpMaker); -REGISTER_OP_CPU_KERNEL( - mpc_greater_equal, - ops::MpcCompareOpKernel); - -REGISTER_OP_WITHOUT_GRADIENT(mpc_less_than, ops::MpcCompareOp, - ops::MpcCompareOpMaker); -REGISTER_OP_CPU_KERNEL( - mpc_less_than, ops::MpcCompareOpKernel); - -REGISTER_OP_WITHOUT_GRADIENT(mpc_less_equal, ops::MpcCompareOp, - ops::MpcCompareOpMaker); -REGISTER_OP_CPU_KERNEL( - mpc_less_equal, ops::MpcCompareOpKernel); - -REGISTER_OP_WITHOUT_GRADIENT(mpc_equal, ops::MpcCompareOp, - ops::MpcCompareOpMaker); -REGISTER_OP_CPU_KERNEL( - mpc_equal, ops::MpcCompareOpKernel); - -REGISTER_OP_WITHOUT_GRADIENT(mpc_not_equal, ops::MpcCompareOp, - ops::MpcCompareOpMaker); -REGISTER_OP_CPU_KERNEL( - mpc_not_equal, ops::MpcCompareOpKernel); +REGISTER_OP_WITHOUT_GRADIENT(mpc_greater_than, ops::MpcCompareOp, ops::MpcCompareOpMaker); +REGISTER_OP_CPU_KERNEL(mpc_greater_than, + ops::MpcCompareOpKernel); + +REGISTER_OP_WITHOUT_GRADIENT(mpc_greater_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker); +REGISTER_OP_CPU_KERNEL(mpc_greater_equal, + ops::MpcCompareOpKernel); + +REGISTER_OP_WITHOUT_GRADIENT(mpc_less_than, ops::MpcCompareOp, ops::MpcCompareOpMaker); +REGISTER_OP_CPU_KERNEL(mpc_less_than, + ops::MpcCompareOpKernel); + +REGISTER_OP_WITHOUT_GRADIENT(mpc_less_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker); +REGISTER_OP_CPU_KERNEL(mpc_less_equal, + ops::MpcCompareOpKernel); + +REGISTER_OP_WITHOUT_GRADIENT(mpc_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker); +REGISTER_OP_CPU_KERNEL(mpc_equal, + ops::MpcCompareOpKernel); + +REGISTER_OP_WITHOUT_GRADIENT(mpc_not_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker); +REGISTER_OP_CPU_KERNEL(mpc_not_equal, + ops::MpcCompareOpKernel); + diff --git a/core/paddlefl_mpc/operators/mpc_compare_op.h b/core/paddlefl_mpc/operators/mpc_compare_op.h index 0b3adbad09cf061c7d95c0cd2873351567e1e13b..88987eb0d77318d3c60a0aa4d14bc13bbf59e49e 100644 --- a/core/paddlefl_mpc/operators/mpc_compare_op.h +++ b/core/paddlefl_mpc/operators/mpc_compare_op.h @@ -1,22 +1,19 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.uage governing permissions and +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #pragma once #include "mpc_op.h" -#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h" -#include -#include namespace paddle { namespace operators { @@ -24,58 +21,52 @@ namespace operators { using Tensor = framework::Tensor; struct MpcGreaterThanFunctor { - void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) { - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->gt( - in_x_t, in_y_t, out_t); - } + void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) { + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->gt(in_x_t, in_y_t, out_t); + } }; struct MpcGreaterEqualFunctor { - void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) { - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->geq( - in_x_t, in_y_t, out_t); - } + void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) { + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->geq(in_x_t, in_y_t, out_t); + } }; struct MpcLessThanFunctor { - void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) { - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->lt( - in_x_t, in_y_t, out_t); - } + void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) { + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->lt(in_x_t, in_y_t, out_t); + } }; struct MpcLessEqualFunctor { - void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) { - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->leq( - in_x_t, in_y_t, out_t); - } + void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) { + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->leq(in_x_t, in_y_t, out_t); + } }; struct MpcEqualFunctor { - void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) { - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->eq( - in_x_t, in_y_t, out_t); - } + void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) { + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->eq(in_x_t, in_y_t, out_t); + } }; struct MpcNotEqualFunctor { - void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) { - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->neq( - in_x_t, in_y_t, out_t); - } + void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) { + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->neq(in_x_t, in_y_t, out_t); + } }; template class MpcCompareOpKernel : public MpcOpKernel { public: - void ComputeImpl(const framework::ExecutionContext &ctx) const override { - auto *in_x_t = ctx.Input("X"); - auto *in_y_t = ctx.Input("Y"); - auto *out_t = ctx.Output("Out"); + void ComputeImpl(const framework::ExecutionContext &ctx) const override{ + auto *in_x_t = ctx.Input("X"); + auto *in_y_t = ctx.Input("Y"); + auto *out_t = ctx.Output("Out"); - auto out = out_t->mutable_data(ctx.GetPlace()); - Functor().Run(in_x_t, in_y_t, out_t); - } + auto out = out_t->mutable_data(ctx.GetPlace()); + Functor().Run(in_x_t, in_y_t, out_t); + } }; -} // namespace operators -} // 
namespace paddl +} // namespace operators +} // namespace paddl diff --git a/core/paddlefl_mpc/operators/mpc_elementwise_add_op.cc b/core/paddlefl_mpc/operators/mpc_elementwise_add_op.cc index 6a3a32d313b56d868ae97d8aefe0a919ed486da0..476eab7db7486082399c6ef2bf951eece005185f 100644 --- a/core/paddlefl_mpc/operators/mpc_elementwise_add_op.cc +++ b/core/paddlefl_mpc/operators/mpc_elementwise_add_op.cc @@ -1,19 +1,19 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ -#include "mpc_elementwise_add_op.h" #include "paddle/fluid/framework/op_registry.h" +#include "mpc_elementwise_add_op.h" namespace paddle { namespace operators { @@ -22,111 +22,105 @@ using Tensor = framework::Tensor; class MpcElementwiseAddOp : public framework::OperatorWithKernel { public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("X"), true, - platform::errors::NotFound( - "Input(X) of MpcElementwiseAddOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("Y"), true, - platform::errors::NotFound( - "Input(Y) of MpcElementwiseAddOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("Out"), true, - platform::errors::NotFound( - "Output(Out) of MpcElementwiseAddOp should not be null.")); - PADDLE_ENFORCE_GE( - ctx->GetInputDim("X").size(), ctx->GetInputDim("Y").size(), - platform::errors::InvalidArgument( - "The dimensions of X should be greater than the dimensions of Y. " - "But received the dimensions of X is [%s], the dimensions of Y is " - "[%s]", + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("X"), true, + platform::errors::NotFound("Input(X) of MpcElementwiseAddOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Y"), true, + platform::errors::NotFound("Input(Y) of MpcElementwiseAddOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("Out"), true, + platform::errors::NotFound("Output(Out) of MpcElementwiseAddOp should not be null.")); + PADDLE_ENFORCE_GE( + ctx->GetInputDim("X").size(), ctx->GetInputDim("Y").size(), + platform::errors::InvalidArgument( + "The dimensions of X should be greater than the dimensions of Y. 
" + "But received the dimensions of X is [%s], the dimensions of Y is [%s]", ctx->GetInputDim("X"), ctx->GetInputDim("Y"))); - ctx->ShareDim("X", /*->*/ "Out"); - ctx->ShareLoD("X", /*->*/ "Out"); - } + ctx->ShareDim("X", /*->*/ "Out"); + ctx->ShareLoD("X", /*->*/ "Out"); + } + }; class MpcElementwiseAddOpMaker : public framework::OpProtoAndCheckerMaker { public: - void Make() override { - AddInput("X", - "(Tensor), The first input tensor of mpc elementwise add op."); - AddInput("Y", - "(Tensor), The second input tensor of mpc elementwise add op."); - AddOutput("Out", "(Tensor), The output tensor of mpc elementwise add op."); - AddAttr("axis", + void Make() override { + AddInput("X", "(Tensor), The first input tensor of mpc elementwise add op."); + AddInput("Y", "(Tensor), The second input tensor of mpc elementwise add op."); + AddOutput("Out", "(Tensor), The output tensor of mpc elementwise add op."); + AddAttr("axis", "(int, default -1). If X.dimension != Y.dimension," "Y.dimension must be a subsequence of x.dimension. And axis " "is the start dimension index " "for broadcasting Y onto X. ") .SetDefault(-1) .EqualGreaterThan(-1); - AddComment(R"DOC( + AddComment(R"DOC( MPC elementwise add Operator. )DOC"); - } + } }; class MpcElementwiseAddGradOp : public framework::OperatorWithKernel { public: - using framework::OperatorWithKernel::OperatorWithKernel; - using Tensor = framework::Tensor; - - void InferShape(framework::InferShapeContext *ctx) const override { - auto out_grad_name = framework::GradVarName("Out"); - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null."); - PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null."); - PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true, + using framework::OperatorWithKernel::OperatorWithKernel; + using Tensor = framework::Tensor; + + void InferShape(framework::InferShapeContext *ctx) const override { + auto out_grad_name = framework::GradVarName("Out"); + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true, "Input(Out@GRAD) should not be null."); - auto x_grad_name = framework::GradVarName("X"); - auto y_grad_name = framework::GradVarName("Y"); - if (ctx->HasOutput(x_grad_name)) { - ctx->ShareDim("X", /*->*/ x_grad_name); - ctx->ShareLoD("X", /*->*/ x_grad_name); - } - if (ctx->HasOutput(y_grad_name)) { - ctx->ShareDim("Y", /*->*/ y_grad_name); - ctx->ShareLoD("Y", /*->*/ y_grad_name); + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(x_grad_name)) { + ctx->ShareDim("X", /*->*/ x_grad_name); + ctx->ShareLoD("X", /*->*/ x_grad_name); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->ShareDim("Y", /*->*/ y_grad_name); + ctx->ShareLoD("Y", /*->*/ y_grad_name); + } } - } + }; template -class MpcElementwiseAddOpGradMaker : public framework::SingleGradOpDescMaker { +class MpcElementwiseAddOpGradMaker : public framework::SingleGradOpMaker { public: - using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + using framework::SingleGradOpMaker::SingleGradOpMaker; protected: - std::unique_ptr Apply() const override { - std::unique_ptr retv(new T()); - retv->SetType("mpc_elementwise_add_grad"); - retv->SetInput("X", this->Input("X")); - retv->SetInput("Y", this->Input("Y")); - retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - 
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); - retv->SetAttrMap(this->Attrs()); - return retv; - } + void Apply(GradOpPtr grad) const override { + grad->SetType("mpc_elementwise_add_grad"); + grad->SetInput("X", this->Input("X")); + grad->SetInput("Y", this->Input("Y")); + grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + grad->SetAttrMap(this->Attrs()); + } }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(mpc_elementwise_add, ops::MpcElementwiseAddOp, - ops::MpcElementwiseAddOpMaker, - ops::MpcElementwiseAddOpGradMaker); +REGISTER_OPERATOR(mpc_elementwise_add, ops::MpcElementwiseAddOp, + ops::MpcElementwiseAddOpMaker, + ops::MpcElementwiseAddOpGradMaker); -REGISTER_OPERATOR(mpc_elementwise_add_grad, ops::MpcElementwiseAddGradOp); +REGISTER_OPERATOR(mpc_elementwise_add_grad, ops::MpcElementwiseAddGradOp); REGISTER_OP_CPU_KERNEL( - mpc_elementwise_add, + mpc_elementwise_add, ops::MpcElementwiseAddKernel); -REGISTER_OP_CPU_KERNEL(mpc_elementwise_add_grad, - ops::MpcElementwiseAddGradKernel< - paddle::platform::CPUDeviceContext, int64_t>); +REGISTER_OP_CPU_KERNEL( + mpc_elementwise_add_grad, + ops::MpcElementwiseAddGradKernel); diff --git a/core/paddlefl_mpc/operators/mpc_elementwise_add_op.h b/core/paddlefl_mpc/operators/mpc_elementwise_add_op.h index e25e8709f712487186707033490cf89d998c1cf4..d9345c757ec42f9dcc220566fe6c8a78dd64cc35 100644 --- a/core/paddlefl_mpc/operators/mpc_elementwise_add_op.h +++ b/core/paddlefl_mpc/operators/mpc_elementwise_add_op.h @@ -1,16 +1,16 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ // This op is different with elementwise_add of PaddlePaddle. // We only consider that the dimensions of X is equal with the dimensions of Y. 
@@ -18,7 +18,6 @@ #pragma once #include "mpc_op.h" #include "paddle/fluid/platform/transform.h" -#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h" namespace paddle { namespace operators { @@ -26,189 +25,187 @@ namespace operators { using Tensor = framework::Tensor; // paddle/fluid/operators/elementwise/elementwise_op_function.h -template class RowwiseTransformIterator; +template +class RowwiseTransformIterator; template class RowwiseTransformIterator - : public std::iterator { + : public std::iterator { public: - RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {} + RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {} - RowwiseTransformIterator &operator++() { - ++i_; - if (UNLIKELY(i_ == n_)) { - i_ = 0; + RowwiseTransformIterator &operator++() { + ++i_; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } + return *this; } - return *this; - } - RowwiseTransformIterator &operator+(int n) { - while (n-- > 0) { - ++i_; - if (UNLIKELY(i_ == n_)) { - i_ = 0; - } - } + RowwiseTransformIterator &operator+(int n) { + while (n-- > 0) { + ++i_; + if (UNLIKELY(i_ == n_)) { + i_ = 0; + } + } - return *this; - } + return *this; + } - bool operator==(const RowwiseTransformIterator - &rhs) const { - return (ptr_ + i_) == &(*rhs); - } + bool operator==(const RowwiseTransformIterator &rhs) const { + return (ptr_ + i_) == &(*rhs); + } - bool operator!=(const RowwiseTransformIterator - &rhs) const { - return (ptr_ + i_) != &(*rhs); - } + bool operator!=(const RowwiseTransformIterator &rhs) const { + return (ptr_ + i_) != &(*rhs); + } - const T &operator*() { return ptr_[i_]; } + const T &operator*() { return ptr_[i_]; } private: - const T *ptr_; - int i_; - int64_t n_; + const T *ptr_; + int i_; + int64_t n_; }; -template struct AddFunctor { - inline HOSTDEVICE T operator()(T x, T y) { return x + y; } +template +struct AddFunctor { + inline HOSTDEVICE T operator()(T x, T y) { return x + y; } }; struct GetMidDims { - inline HOSTDEVICE void operator()(const framework::DDim &x_dims, - const framework::DDim &y_dims, - const int axis, int *pre, int *n, - int *post) { - *pre = 1; - *n = 1; - *post = 1; - for (int i = 1; i < axis + 1; ++i) { - (*pre) *= x_dims[i]; - } + inline HOSTDEVICE void operator()(const framework::DDim &x_dims, + const framework::DDim &y_dims, const int axis, + int *pre, int *n, int *post) { + *pre = 1; + *n = 1; + *post = 1; + for (int i = 1; i < axis + 1; ++i) { + (*pre) *= x_dims[i]; + } - for (int i = 1; i < y_dims.size(); ++i) { - PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i], - "Broadcast dimension mismatch."); - (*n) *= y_dims[i]; - } + for (int i = 1; i < y_dims.size(); ++i) { + PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i], + "Broadcast dimension mismatch."); + (*n) *= y_dims[i]; + } - for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) { - (*post) *= x_dims[i]; + for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) { + (*post) *= x_dims[i]; + } } - } }; -const size_t SHARE_NUM = 2; +const size_t SHARE_NUM = 2; template class MpcElementwiseAddKernel : public MpcOpKernel { public: - void ComputeImpl(const framework::ExecutionContext &ctx) const override { - auto *in_x_t = ctx.Input("X"); - auto *in_y_t = ctx.Input("Y"); - auto *out_t = ctx.Output("Out"); - - int axis = ctx.Attr("axis"); - - auto out = out_t->mutable_data(ctx.GetPlace()); - - if (in_x_t->dims() == in_y_t->dims()) { - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add( - in_x_t, in_y_t, out_t); - } else { - Tensor in_x_t_slice; - Tensor 
in_y_t_slice; - Tensor out_t_slice; - - for (size_t i = 0; i < SHARE_NUM; ++i) { - in_x_t_slice = in_x_t->Slice(i, i + 1); - in_y_t_slice = in_y_t->Slice(i, i + 1); - out_t_slice = out_t->Slice(i, i + 1); - - auto x_dims = in_x_t_slice.dims(); - auto y_dims = in_y_t_slice.dims(); - - axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); - PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), - "Axis should be in range [0, x_dims)"); - - int pre, n, post; - GetMidDims get_mid_dims; - get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post); - PADDLE_ENFORCE_EQ( - post, 1, "post should be equal 1, but received post is [%s]", post); - - auto x_ = in_x_t_slice.data(); - auto y_ = in_y_t_slice.data(); - auto out_ = out_t_slice.data(); - auto nx_ = in_x_t_slice.numel(); - paddle::platform::Transform trans; - trans(ctx.template device_context(), x_, x_ + nx_, - RowwiseTransformIterator(y_, n), out_, - AddFunctor()); - } - } + void ComputeImpl(const framework::ExecutionContext &ctx) const override{ + auto *in_x_t = ctx.Input("X"); + auto *in_y_t = ctx.Input("Y"); + auto *out_t = ctx.Output("Out"); + + int axis = ctx.Attr("axis"); + + auto out = out_t->mutable_data(ctx.GetPlace()); + + if (in_x_t->dims() == in_y_t->dims()) { + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(in_x_t, in_y_t, out_t); + } else { + Tensor in_x_t_slice; + Tensor in_y_t_slice; + Tensor out_t_slice; + + for (size_t i = 0; i < SHARE_NUM; ++i) { + in_x_t_slice = in_x_t->Slice(i, i + 1); + in_y_t_slice = in_y_t->Slice(i, i + 1); + out_t_slice = out_t->Slice(i, i + 1); + + auto x_dims = in_x_t_slice.dims(); + auto y_dims = in_y_t_slice.dims(); + + axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); + PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), + "Axis should be in range [0, x_dims)"); + + int pre, n, post; + GetMidDims get_mid_dims; + get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post); + PADDLE_ENFORCE_EQ(post, 1, + "post should be equal 1, but received post is [%s]", post); + + auto x_ = in_x_t_slice.data(); + auto y_ = in_y_t_slice.data(); + auto out_ = out_t_slice.data(); + auto nx_ = in_x_t_slice.numel(); + paddle::platform::Transform trans; + trans(ctx.template device_context(), x_, x_ + nx_, + RowwiseTransformIterator(y_, n), + out_, AddFunctor()); + } + } } }; template class MpcElementwiseAddGradKernel : public MpcOpKernel { public: - void ComputeImpl(const framework::ExecutionContext &ctx) const override { - auto *in_x_t = ctx.Input("X"); - auto *in_y_t = ctx.Input("Y"); - auto *dout = ctx.Input(framework::GradVarName("Out")); - auto *dx = ctx.Output(framework::GradVarName("X")); - auto *dy = ctx.Output(framework::GradVarName("Y")); - int axis = ctx.Attr("axis"); - auto dout_data = dout->data(); - - if (dx) { - auto dx_data = dx->mutable_data(ctx.GetPlace()); - for (size_t i = 0; i < dout->numel(); i++) { - dx_data[i] = dout_data[i]; - } - } - - if (dy) { - auto dy_data = dy->mutable_data(ctx.GetPlace()); - if (in_x_t->dims().size() == in_y_t->dims().size()) { - for (size_t i = 0; i < dout->numel(); i++) { - dy_data[i] = dout_data[i]; + void ComputeImpl(const framework::ExecutionContext &ctx) const override { + auto *in_x_t = ctx.Input("X"); + auto *in_y_t = ctx.Input("Y"); + auto *dout = ctx.Input(framework::GradVarName("Out")); + auto *dx = ctx.Output(framework::GradVarName("X")); + auto *dy = ctx.Output(framework::GradVarName("Y")); + int axis = ctx.Attr("axis"); + auto dout_data = dout->data(); + + if (dx) { + auto dx_data = dx->mutable_data(ctx.GetPlace()); + for (size_t i = 0; i 
< dout->numel(); i++) { + dx_data[i] = dout_data[i]; + } } - } else { - auto x_dims = in_x_t->dims(); - auto y_dims = in_y_t->dims(); - - axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); - PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), - "Axis should be in range [0, x_dims)"); - - int pre, n, post; - GetMidDims get_mid_dims; - get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post); - PADDLE_ENFORCE_EQ( - post, 1, "post should be equal 1, but received post is [%s]", post); - - for (size_t i = 0; i < SHARE_NUM; ++i) { - int y_offset = i * n; - for (size_t j = 0; j < pre; ++j) { - for (size_t k = 0; k < n; ++k) { - int out_offset = i * pre * n + j * n + k; - if (0 == j) { - dy_data[k + y_offset] = dout_data[out_offset]; - } else { - dy_data[k + y_offset] += dout_data[out_offset]; - } + + if (dy) { + auto dy_data = dy->mutable_data(ctx.GetPlace()); + if (in_x_t->dims().size() == in_y_t->dims().size()) { + for (size_t i = 0; i < dout->numel(); i++) { + dy_data[i] = dout_data[i]; + } + } else { + auto x_dims = in_x_t->dims(); + auto y_dims = in_y_t->dims(); + + axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); + PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), + "Axis should be in range [0, x_dims)"); + + int pre, n, post; + GetMidDims get_mid_dims; + get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post); + PADDLE_ENFORCE_EQ(post, 1, + "post should be equal 1, but received post is [%s]", post); + + for (size_t i = 0; i < SHARE_NUM; ++i) { + int y_offset = i * n; + for (size_t j = 0; j < pre; ++j) { + for (size_t k = 0; k < n; ++k) { + int out_offset = i * pre * n + j * n + k; + if (0 == j) { + dy_data[k + y_offset] = dout_data[out_offset]; + } else { + dy_data[k + y_offset] += dout_data[out_offset]; + } + } + } + } } - } } - } } - } }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle + diff --git a/core/paddlefl_mpc/operators/mpc_elementwise_sub_op.cc b/core/paddlefl_mpc/operators/mpc_elementwise_sub_op.cc index 322b17238ef9c676559dc83f1ebd02a046025e70..bb3e373cfdc90fe052178bf75f2d56588b8b7880 100644 --- a/core/paddlefl_mpc/operators/mpc_elementwise_sub_op.cc +++ b/core/paddlefl_mpc/operators/mpc_elementwise_sub_op.cc @@ -1,123 +1,115 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ -#include "mpc_elementwise_sub_op.h" #include "paddle/fluid/framework/op_registry.h" +#include "mpc_elementwise_sub_op.h" namespace paddle { namespace operators { class MpcElementwiseSubOp : public framework::OperatorWithKernel { public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("X"), true, - platform::errors::NotFound( - "Input(X) of MpcElementwiseSubOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("Y"), true, - platform::errors::NotFound( - "Input(Y) of MpcElementwiseSubOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("Out"), true, - platform::errors::NotFound( - "Output(Out) of MpcElementwiseSubOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->GetInputDim("X"), ctx->GetInputDim("Y"), - platform::errors::InvalidArgument( - "The dimensions of X should be equal with the dimensions of Y. " - "But received the dimensions of X is [%s], the dimensions of Y is " - "[%s]", - ctx->GetInputDim("X"), ctx->GetInputDim("Y"))); + using framework::OperatorWithKernel::OperatorWithKernel; - ctx->ShareDim("X", /*->*/ "Out"); - ctx->ShareLoD("X", /*->*/ "Out"); - } + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("X"), true, + platform::errors::NotFound("Input(X) of MpcElementwiseSubOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Y"), true, + platform::errors::NotFound("Input(Y) of MpcElementwiseSubOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("Out"), true, + platform::errors::NotFound("Output(Out) of MpcElementwiseSubOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->GetInputDim("X"), ctx->GetInputDim("Y"), + platform::errors::InvalidArgument( + "The dimensions of X should be equal with the dimensions of Y. " + "But received the dimensions of X is [%s], the dimensions of Y is [%s]", + ctx->GetInputDim("X"), ctx->GetInputDim("Y"))); + + ctx->ShareDim("X", /*->*/ "Out"); + ctx->ShareLoD("X", /*->*/ "Out"); + } }; class MpcElementwiseSubOpMaker : public framework::OpProtoAndCheckerMaker { public: - void Make() override { - AddInput("X", - "(Tensor), The first input tensor of mpc elementwise sub op."); - AddInput("Y", - "(Tensor), The second input tensor of mpc elementwise sub op."); - AddOutput("Out", "(Tensor), The output tensor of mpc elementwise sub op."); - AddComment(R"DOC( + void Make() override { + AddInput("X", "(Tensor), The first input tensor of mpc elementwise sub op."); + AddInput("Y", "(Tensor), The second input tensor of mpc elementwise sub op."); + AddOutput("Out", "(Tensor), The output tensor of mpc elementwise sub op."); + AddComment(R"DOC( MPC elementwise sub Operator. 
)DOC"); - } + } }; class MpcElementwiseSubGradOp : public framework::OperatorWithKernel { public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - auto out_grad_name = framework::GradVarName("Out"); - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null."); - PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null."); - PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true, + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + auto out_grad_name = framework::GradVarName("Out"); + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true, "Input(Out@GRAD) should not be null."); - auto x_grad_name = framework::GradVarName("X"); - auto y_grad_name = framework::GradVarName("Y"); - if (ctx->HasOutput(x_grad_name)) { - ctx->ShareDim("X", /*->*/ x_grad_name); - ctx->ShareLoD("X", /*->*/ x_grad_name); + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(x_grad_name)) { + ctx->ShareDim("X", /*->*/ x_grad_name); + ctx->ShareLoD("X", /*->*/ x_grad_name); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->ShareDim("Y", /*->*/ y_grad_name); + ctx->ShareLoD("Y", /*->*/ y_grad_name); + } } - if (ctx->HasOutput(y_grad_name)) { - ctx->ShareDim("Y", /*->*/ y_grad_name); - ctx->ShareLoD("Y", /*->*/ y_grad_name); - } - } }; template -class MpcElementwiseSubGradMaker : public framework::SingleGradOpDescMaker { +class MpcElementwiseSubGradMaker : public framework::SingleGradOpMaker { public: - using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + using framework::SingleGradOpMaker::SingleGradOpMaker; protected: - std::unique_ptr Apply() const override { - std::unique_ptr retv(new T()); - retv->SetType("mpc_elementwise_sub_grad"); - retv->SetInput("X", this->Input("X")); - retv->SetInput("Y", this->Input("Y")); - retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); - retv->SetAttrMap(this->Attrs()); - return retv; - } + void Apply(GradOpPtr grad) const override { + grad->SetType("mpc_elementwise_sub_grad"); + grad->SetInput("X", this->Input("X")); + grad->SetInput("Y", this->Input("Y")); + grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + grad->SetAttrMap(this->Attrs()); + } }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(mpc_elementwise_sub, ops::MpcElementwiseSubOp, - ops::MpcElementwiseSubOpMaker, - ops::MpcElementwiseSubGradMaker); +REGISTER_OPERATOR(mpc_elementwise_sub, ops::MpcElementwiseSubOp, + ops::MpcElementwiseSubOpMaker, + ops::MpcElementwiseSubGradMaker); -REGISTER_OPERATOR(mpc_elementwise_sub_grad, ops::MpcElementwiseSubGradOp); +REGISTER_OPERATOR(mpc_elementwise_sub_grad, ops::MpcElementwiseSubGradOp); REGISTER_OP_CPU_KERNEL( - mpc_elementwise_sub, + mpc_elementwise_sub, ops::MpcElementwiseSubKernel); -REGISTER_OP_CPU_KERNEL(mpc_elementwise_sub_grad, - 
ops::MpcElementwiseSubGradKernel< - paddle::platform::CPUDeviceContext, int64_t>); +REGISTER_OP_CPU_KERNEL( + mpc_elementwise_sub_grad, + ops::MpcElementwiseSubGradKernel); diff --git a/core/paddlefl_mpc/operators/mpc_elementwise_sub_op.h b/core/paddlefl_mpc/operators/mpc_elementwise_sub_op.h index 3f467da245b9dfbdcf20303f79149dc5fe23b7ef..33ce6f83b9605de9150db5a234f539632f23935d 100644 --- a/core/paddlefl_mpc/operators/mpc_elementwise_sub_op.h +++ b/core/paddlefl_mpc/operators/mpc_elementwise_sub_op.h @@ -1,23 +1,22 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ // This op is different with elementwise_sub of PaddlePaddle. // We only consider that the dimensions of X is equal with the dimensions of Y. 
#pragma once #include "mpc_op.h" -#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h" namespace paddle { namespace operators { @@ -27,40 +26,39 @@ using Tensor = framework::Tensor; template class MpcElementwiseSubKernel : public MpcOpKernel { public: - void ComputeImpl(const framework::ExecutionContext &ctx) const override { - auto *in_x_t = ctx.Input("X"); - auto *in_y_t = ctx.Input("Y"); - auto *out_t = ctx.Output("Out"); + void ComputeImpl(const framework::ExecutionContext &ctx) const override { + auto *in_x_t = ctx.Input("X"); + auto *in_y_t = ctx.Input("Y"); + auto *out_t = ctx.Output("Out"); - auto out = out_t->mutable_data(ctx.GetPlace()); - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub( - in_x_t, in_y_t, out_t); - } + auto out = out_t->mutable_data(ctx.GetPlace()); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(in_x_t, in_y_t, out_t); + } }; template class MpcElementwiseSubGradKernel : public MpcOpKernel { public: - void ComputeImpl(const framework::ExecutionContext &ctx) const override { - VLOG(3) << "******** MpcElementwiseSubGradKernel: "; - auto *dout = ctx.Input(framework::GradVarName("Out")); - auto *dx = ctx.Output(framework::GradVarName("X")); - auto *dy = ctx.Output(framework::GradVarName("Y")); - auto dout_data = dout->data(); + void ComputeImpl(const framework::ExecutionContext &ctx) const override { + VLOG(3) << "******** MpcElementwiseSubGradKernel: "; + auto *dout = ctx.Input(framework::GradVarName("Out")); + auto *dx = ctx.Output(framework::GradVarName("X")); + auto *dy = ctx.Output(framework::GradVarName("Y")); + auto dout_data = dout->data(); - if (dx) { - auto dx_data = dx->mutable_data(ctx.GetPlace()); - for (size_t i = 0; i < dout->numel(); i++) { - dx_data[i] = dout_data[i]; - } - } - if (dy) { - dy->mutable_data(ctx.GetPlace()); - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->neg( - dout, dy); + if (dx) { + auto dx_data = dx->mutable_data(ctx.GetPlace()); + for (size_t i = 0; i < dout->numel(); i++) { + dx_data[i] = dout_data[i]; + } + } + if (dy) { + dy->mutable_data(ctx.GetPlace()); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->neg(dout, dy); + } } - } }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle + diff --git a/core/paddlefl_mpc/operators/mpc_init_op.cc b/core/paddlefl_mpc/operators/mpc_init_op.cc index 4ecc58c0005c18c39ee56584e2cf88c688ebc9c3..04a806dc5f33bc5dddfbc34ec0dacb8846a7b298 100644 --- a/core/paddlefl_mpc/operators/mpc_init_op.cc +++ b/core/paddlefl_mpc/operators/mpc_init_op.cc @@ -1,22 +1,22 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ // Description: #include "paddle/fluid/framework/op_registry.h" -#include "core/paddlefl_mpc/mpc_protocol/mpc_config.h" #include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h" +#include "core/paddlefl_mpc/mpc_protocol/mpc_config.h" namespace paddle { namespace operators { @@ -26,59 +26,63 @@ using mpc::Aby3Config; class MpcInitOp : public framework::OperatorBase { public: - MpcInitOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) {} - - void RunImpl(const framework::Scope &scope, - const platform::Place &dev_place) const override { - auto protocol_name = Attr("protocol_name"); - auto role = Attr("role"); - auto local_addr = Attr("local_addr"); - auto net_server_addr = Attr("net_server_addr"); - auto net_server_port = Attr("net_server_port"); - - MpcConfig _mpc_config; - _mpc_config.set_int(Aby3Config::ROLE, role); - _mpc_config.set(Aby3Config::LOCAL_ADDR, local_addr); - _mpc_config.set(Aby3Config::NET_SERVER_ADDR, net_server_addr); - _mpc_config.set_int(Aby3Config::NET_SERVER_PORT, net_server_port); - mpc::MpcInstance::init_instance(protocol_name, _mpc_config); - } + MpcInitOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + auto protocol_name = Attr("protocol_name"); + auto role = Attr("role"); + auto local_addr = Attr("local_addr"); + auto net_server_addr = Attr("net_server_addr"); + auto net_server_port = Attr("net_server_port"); + + MpcConfig _mpc_config; + _mpc_config.set_int(Aby3Config::ROLE, role); + _mpc_config.set(Aby3Config::LOCAL_ADDR, local_addr); + _mpc_config.set(Aby3Config::NET_SERVER_ADDR, net_server_addr); + _mpc_config.set_int(Aby3Config::NET_SERVER_PORT, net_server_port); + mpc::MpcInstance::init_instance(protocol_name, _mpc_config); + } }; class MpcInitOpMaker : public framework::OpProtoAndCheckerMaker { public: - void Make() override { + void Make() override { - AddComment(R"DOC( + AddComment(R"DOC( Where2 Operator. 
)DOC"); - AddAttr("protocol_name", "(string , default aby3)" - "protocol name") + AddAttr("protocol_name", + "(string , default aby3)" + "protocol name") .SetDefault({"aby3"}); - AddAttr("role", "trainer role.").SetDefault(0); - AddAttr("local_addr", "(string, default localhost)" - "local addr") + AddAttr("role", "trainer role.").SetDefault(0); + AddAttr("local_addr", + "(string, default localhost)" + "local addr") .SetDefault({"localhost"}); - AddAttr("net_server_addr", "(string, default localhost)" - "net server addr") + AddAttr("net_server_addr", + "(string, default localhost)" + "net server addr") .SetDefault({"localhost"}); - AddAttr("net_server_port", "net server port, default to 6539.") - .SetDefault(6539); - } + AddAttr("net_server_port", "net server port, default to 6539.").SetDefault(6539); + } }; class MpcInitOpShapeInference : public framework::InferShapeBase { -public: - void operator()(framework::InferShapeContext *ctx) const override {} + public: + void operator()(framework::InferShapeContext* ctx) const override {} }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(mpc_init, ops::MpcInitOp, ops::MpcInitOpMaker, - ops::MpcInitOpShapeInference); +REGISTER_OPERATOR( + mpc_init, ops::MpcInitOp, + ops::MpcInitOpMaker, ops::MpcInitOpShapeInference); + diff --git a/core/paddlefl_mpc/operators/mpc_mean_op.cc b/core/paddlefl_mpc/operators/mpc_mean_op.cc index 84deeb93d4fb4972d2dc3488685c599402915a0a..626fb204dd49a54d805fd8428ab52cb2fb4fa407 100644 --- a/core/paddlefl_mpc/operators/mpc_mean_op.cc +++ b/core/paddlefl_mpc/operators/mpc_mean_op.cc @@ -1,19 +1,19 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ -#include "mpc_mean_op.h" #include "paddle/fluid/framework/op_registry.h" +#include "mpc_mean_op.h" namespace paddle { namespace operators { @@ -22,78 +22,80 @@ using Tensor = framework::Tensor; class MpcMeanOp : public framework::OperatorWithKernel { public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - platform::errors::NotFound( - "Input(X) of MpcMeanOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - platform::errors::NotFound( - "Output(Out) of MpcMeanOp should not be null.")); - ctx->SetOutputDim("Out", {2, 1}); - } + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("X"), true, + platform::errors::NotFound("Input(X) of MpcMeanOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("Out"), true, + platform::errors::NotFound("Output(Out) of MpcMeanOp should not be null.")); + ctx->SetOutputDim("Out", {2, 1}); + } }; class MpcMeanOpMaker : public framework::OpProtoAndCheckerMaker { public: - void Make() override { - AddInput("X", "(Tensor), The first input tensor of mpc mean op."); - AddOutput("Out", "(Tensor), The output tensor of mpc mean op."); - AddComment(R"DOC( + void Make() override { + AddInput("X", "(Tensor), The first input tensor of mpc mean op."); + AddOutput("Out", "(Tensor), The output tensor of mpc mean op."); + AddComment(R"DOC( MPC mean Operator calculates the mean of all elements in X. )DOC"); - } + } }; class MpcMeanOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { protected: - std::unordered_map - GetInputOutputWithSameType() const override { - return std::unordered_map{{"X", /*->*/ "Out"}}; - } + std::unordered_map& GetInputOutputWithSameType() + const override { + static std::unordered_map m{{"X", /*->*/ "Out"}}; + return m; + } }; class MpcMeanGradOp : public framework::OperatorWithKernel { public: - using framework::OperatorWithKernel::OperatorWithKernel; - using Tensor = framework::Tensor; + using framework::OperatorWithKernel::OperatorWithKernel; + using Tensor = framework::Tensor; + + void InferShape(framework::InferShapeContext *ctx) const override { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + ctx->ShareLoD("X", framework::GradVarName("X")); + } - void InferShape(framework::InferShapeContext *ctx) const override { - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); - ctx->ShareLoD("X", framework::GradVarName("X")); - } }; template -class MpcMeanOpGradMaker : public framework::SingleGradOpDescMaker { +class MpcMeanOpGradMaker : public framework::SingleGradOpMaker { public: - using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + using framework::SingleGradOpMaker::SingleGradOpMaker; protected: - std::unique_ptr Apply() const override { - std::unique_ptr retv(new T()); - retv->SetType("mpc_mean_grad"); - retv->SetInput("X", this->Input("X")); - retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - return retv; - } + void Apply(GradOpPtr grad) const override { + grad->SetType("mpc_mean_grad"); + grad->SetInput("X", this->Input("X")); + grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } }; -} // namespace operators -} // 
namespace paddle +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(mpc_mean, ops::MpcMeanOp, ops::MpcMeanOpMaker, - ops::MpcMeanOpInferVarType, - ops::MpcMeanOpGradMaker); +REGISTER_OPERATOR(mpc_mean, ops::MpcMeanOp, + ops::MpcMeanOpMaker, + ops::MpcMeanOpInferVarType, + ops::MpcMeanOpGradMaker); REGISTER_OPERATOR(mpc_mean_grad, ops::MpcMeanGradOp); REGISTER_OP_CPU_KERNEL( - mpc_mean, ops::MpcMeanKernel); + mpc_mean, + ops::MpcMeanKernel); REGISTER_OP_CPU_KERNEL( - mpc_mean_grad, + mpc_mean_grad, ops::MpcMeanGradKernel); diff --git a/core/paddlefl_mpc/operators/mpc_mean_op.h b/core/paddlefl_mpc/operators/mpc_mean_op.h index a6338c85889df24a8776f4232db48f7bb294f871..e921997370fd76030fb9c5410bc9aa3da24b69db 100644 --- a/core/paddlefl_mpc/operators/mpc_mean_op.h +++ b/core/paddlefl_mpc/operators/mpc_mean_op.h @@ -1,20 +1,19 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
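Note on the grad-maker change above (repeated for the other mpc_* operators in this patch): the makers move from the deprecated SingleGradOpDescMaker, whose Apply() returned a std::unique_ptr, to SingleGradOpMaker, whose Apply() fills in a GradOpPtr passed by the framework. The sketch below shows the new pattern for a hypothetical op; the template arguments and registration form follow the usual Paddle 1.8 signatures and are assumptions, and the snippet presumes the standard op_registry.h include and the paddle::operators namespace.

// Sketch only, not part of the patch: the SingleGradOpMaker pattern adopted here.
template <typename T>
class MyOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad) const override {
    grad->SetType("my_op_grad");
    grad->SetInput("X", this->Input("X"));
    grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    grad->SetAttrMap(this->Attrs());
  }
};

// Registration typically instantiates the maker for both static and dynamic graphs, e.g.:
// REGISTER_OPERATOR(my_op, ops::MyOp, ops::MyOpMaker,
//                   ops::MyOpGradMaker<paddle::framework::OpDesc>,
//                   ops::MyOpGradMaker<paddle::imperative::OpBase>);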
*/ #pragma once #include "mpc_op.h" -#include "paddle/fluid/framework/eigen.h" namespace paddle { namespace operators { @@ -28,43 +27,40 @@ using EigenVector = framework::EigenVector; template class MpcMeanKernel : public MpcOpKernel { public: - void ComputeImpl(const framework::ExecutionContext &ctx) const override { - auto *in_x_t = ctx.Input("X"); - auto *out_t = ctx.Output("Out"); - out_t->mutable_data(ctx.GetPlace()); - double scale = 1.0 / (in_x_t->numel() / 2.0); - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sum( - in_x_t, out_t); - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale( - out_t, scale, out_t); - } + void ComputeImpl(const framework::ExecutionContext &ctx) const override { + auto *in_x_t = ctx.Input("X"); + auto *out_t = ctx.Output("Out"); + out_t->mutable_data(ctx.GetPlace()); + double scale = 1.0 / (in_x_t->numel() / 2.0); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sum(in_x_t, out_t); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(out_t, scale, out_t); + } }; template class MpcMeanGradKernel : public MpcOpKernel { public: - void ComputeImpl(const framework::ExecutionContext &ctx) const override { - auto dout = ctx.Input(framework::GradVarName("Out")); - PADDLE_ENFORCE(dout->numel() == 2, - "numel of MpcMean Gradient should be 2."); - auto dx = ctx.Output(framework::GradVarName("X")); - auto dout_data = dout->data(); - - if (dx) { - auto dx_data = dx->mutable_data(ctx.GetPlace()); - for (size_t i = 0; i < dx->numel() / 2; ++i) { - dx_data[i] = dout_data[0]; - } - for (size_t i = dx->numel() / 2; i < dx->numel(); ++i) { - dx_data[i] = dout_data[1]; - } + void ComputeImpl(const framework::ExecutionContext &ctx) const override { + auto dout = ctx.Input(framework::GradVarName("Out")); + PADDLE_ENFORCE(dout->numel() == 2, "numel of MpcMean Gradient should be 2."); + auto dx = ctx.Output(framework::GradVarName("X")); + auto dout_data = dout->data(); - double scale_factor = 1.0 / (dx->numel() / 2); - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale( - dx, scale_factor, dx); + if (dx) { + auto dx_data = dx->mutable_data(ctx.GetPlace()); + for (size_t i = 0; i < dx->numel() / 2; ++i) { + dx_data[i] = dout_data[0]; + } + for (size_t i = dx->numel() / 2; i < dx->numel(); ++i) { + dx_data[i] = dout_data[1]; + } + + double scale_factor = 1.0 / (dx->numel() / 2); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(dx, scale_factor, dx); + } } - } }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle + diff --git a/core/paddlefl_mpc/operators/mpc_mul_op.cc b/core/paddlefl_mpc/operators/mpc_mul_op.cc index bc03d35f9833012b1726b3bb47ff013db079c560..f12b069e96ef23e11e9df7447b7b2b80ccde9bbf 100644 --- a/core/paddlefl_mpc/operators/mpc_mul_op.cc +++ b/core/paddlefl_mpc/operators/mpc_mul_op.cc @@ -1,19 +1,19 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
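Note on MpcMeanKernel and MpcMeanGradKernel above: an MPC tensor carries its two secret shares along dimension 0, so the logical element count is numel() / 2. The forward kernel sums everything and scales by 1.0 / (numel() / 2.0), and the backward kernel copies the two share components of dout into the matching halves of dx before scaling by the same 1/n factor. A plain-array sketch of the backward step, using hypothetical names and plain doubles rather than the framework's share type:

#include <cstddef>

// Hypothetical sketch: gradient of mean over an MPC tensor laid out as [2, n]
// (two shares of n logical elements each). d(mean)/d(x_i) = 1/n, so every
// element of share s ends up holding dout[s] / n.
void mean_grad_sketch(const double dout[2], double* dx, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) dx[i] = dout[0];      // first share half
  for (std::size_t i = 0; i < n; ++i) dx[n + i] = dout[1];  // second share half
  const double inv_n = 1.0 / static_cast<double>(n);
  for (std::size_t i = 0; i < 2 * n; ++i) dx[i] *= inv_n;   // the real kernel does this via the MPC scale op
}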
-// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ -#include "mpc_mul_op.h" #include "paddle/fluid/framework/op_registry.h" +#include "mpc_mul_op.h" namespace paddle { namespace operators { @@ -22,98 +22,98 @@ using Tensor = framework::Tensor; class MpcMulOp : public framework::OperatorWithKernel { public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - platform::errors::NotFound( - "Input(X) of Mpc MulOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("Y"), true, - platform::errors::NotFound("Input(Y) of MpcMulOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - platform::errors::NotFound( - "Output(Out) of MpcMulOp should not be null.")); - - auto x_dims = ctx->GetInputDim("X"); - auto y_dims = ctx->GetInputDim("Y"); - - int x_num_col_dims = ctx->Attrs().Get("x_num_col_dims"); - int y_num_col_dims = ctx->Attrs().Get("y_num_col_dims"); - - VLOG(3) << "mpc mul operator x.shape=" << x_dims << " y.shape=" << y_dims - << " x_num_col_dims=" << x_num_col_dims - << " y_num_col_dims=" << y_num_col_dims; - - PADDLE_ENFORCE_NE(framework::product(y_dims), 0, - platform::errors::PreconditionNotMet( - "The Input variable Y(%s) has not " - "been initialized. You may need to confirm " - "if you put exe.run(startup_program) " - "after optimizer.minimize function.", - ctx->Inputs("Y").front())); - PADDLE_ENFORCE_GT( - x_dims.size(), x_num_col_dims, - platform::errors::InvalidArgument( - "The input tensor X's dimensions of MpcMulOp " - "should be larger than x_num_col_dims. But received X's " - "dimensions = %d, X's shape = [%s], x_num_col_dims = %d.", - x_dims.size(), x_dims, x_num_col_dims)); - PADDLE_ENFORCE_GT( - y_dims.size(), y_num_col_dims, - platform::errors::InvalidArgument( - "The input tensor Y's dimensions of MpcMulOp " - "should be larger than y_num_col_dims. But received Y's " - "dimensions = %d, Y's shape = [%s], y_num_col_dims = %d.", - y_dims.size(), y_dims, y_num_col_dims)); - - int x_mat_width = 1; - int y_mat_height = 1; - for (size_t i = x_num_col_dims + 1; i < x_dims.size(); i++) { - x_mat_width *= x_dims[i]; - } - for (size_t i = 1; i <= y_num_col_dims; i++) { - y_mat_height *= y_dims[i]; - } - - PADDLE_ENFORCE_EQ( - x_mat_width, y_mat_height, - platform::errors::InvalidArgument( - "After flatten the input tensor X and Y to 2-D dimensions " - "matrix X1 and Y1, the matrix X1's width must be equal with matrix " - "Y1's height. 
But received X's shape = [%s], X1's " - "width = %s; Y's shape = [%s], Y1's height = %s.", - x_dims, x_mat_width, y_dims, y_mat_height)); - - std::vector output_dims; - output_dims.reserve(static_cast(1 + x_num_col_dims + y_dims.size() - - y_num_col_dims)); - - for (int i = 0; i <= x_num_col_dims; ++i) { // i=0, batch_size (share id) - output_dims.push_back(x_dims[i]); + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("X"), true, + platform::errors::NotFound("Input(X) of Mpc MulOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Y"), true, + platform::errors::NotFound("Input(Y) of MpcMulOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("Out"), true, + platform::errors::NotFound("Output(Out) of MpcMulOp should not be null.")); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + + int x_num_col_dims = ctx->Attrs().Get("x_num_col_dims"); + int y_num_col_dims = ctx->Attrs().Get("y_num_col_dims"); + + VLOG(3) << "mpc mul operator x.shape=" << x_dims << " y.shape=" << y_dims + << " x_num_col_dims=" << x_num_col_dims + << " y_num_col_dims=" << y_num_col_dims; + + PADDLE_ENFORCE_NE(framework::product(y_dims), 0, + platform::errors::PreconditionNotMet( + "The Input variable Y(%s) has not " + "been initialized. You may need to confirm " + "if you put exe.run(startup_program) " + "after optimizer.minimize function.", + ctx->Inputs("Y").front())); + PADDLE_ENFORCE_GT( + x_dims.size(), x_num_col_dims, + platform::errors::InvalidArgument( + "The input tensor X's dimensions of MpcMulOp " + "should be larger than x_num_col_dims. But received X's " + "dimensions = %d, X's shape = [%s], x_num_col_dims = %d.", + x_dims.size(), x_dims, x_num_col_dims)); + PADDLE_ENFORCE_GT( + y_dims.size(), y_num_col_dims, + platform::errors::InvalidArgument( + "The input tensor Y's dimensions of MpcMulOp " + "should be larger than y_num_col_dims. But received Y's " + "dimensions = %d, Y's shape = [%s], y_num_col_dims = %d.", + y_dims.size(), y_dims, y_num_col_dims)); + + int x_mat_width = 1; + int y_mat_height = 1; + for (size_t i = x_num_col_dims + 1; i < x_dims.size(); i++) { + x_mat_width *= x_dims[i]; + } + for (size_t i = 1; i <= y_num_col_dims; i++) { + y_mat_height *= y_dims[i]; + } + + PADDLE_ENFORCE_EQ( + x_mat_width, y_mat_height, + platform::errors::InvalidArgument( + "After flatten the input tensor X and Y to 2-D dimensions " + "matrix X1 and Y1, the matrix X1's width must be equal with matrix " + "Y1's height. 
But received X's shape = [%s], X1's " + "width = %s; Y's shape = [%s], Y1's height = %s.", + x_dims, x_mat_width, y_dims, y_mat_height)); + + std::vector output_dims; + output_dims.reserve( + static_cast(1 + x_num_col_dims + y_dims.size() - y_num_col_dims)); + + for (int i = 0; i <= x_num_col_dims; ++i) { // i=0, batch_size (share id) + output_dims.push_back(x_dims[i]); + } + + for (int i = y_num_col_dims + 1; i < y_dims.size(); ++i) { + output_dims.push_back(y_dims[i]); + } + + ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); + ctx->ShareLoD("X", /*->*/ "Out"); } - - for (int i = y_num_col_dims + 1; i < y_dims.size(); ++i) { - output_dims.push_back(y_dims[i]); - } - - ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); - ctx->ShareLoD("X", /*->*/ "Out"); - } }; class MpcMulOpMaker : public framework::OpProtoAndCheckerMaker { public: - void Make() override { - AddInput("X", "(Tensor), The first input tensor of mpc mul op."); - AddInput("Y", "(Tensor), The second input tensor of mpc mul op."); - AddOutput("Out", "(Tensor), The output tensor of mpc mul op."); - AddAttr("use_mkldnn", - "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false); - AddAttr( - "x_num_col_dims", - R"DOC((int, default 1), The mul_op can take tensors with more than two + void Make() override { + AddInput("X", "(Tensor), The first input tensor of mpc mul op."); + AddInput("Y", "(Tensor), The second input tensor of mpc mul op."); + AddOutput("Out", "(Tensor), The output tensor of mpc mul op."); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr( + "x_num_col_dims", + R"DOC((int, default 1), The mul_op can take tensors with more than two dimensions as its inputs. If the input $X$ is a tensor with more than two dimensions, $X$ will be flattened into a two-dimensional matrix first. The flattening rule is: the first `num_col_dims` @@ -129,109 +129,112 @@ public: Thus, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = [24, 30]. )DOC") - .SetDefault(1) - .EqualGreaterThan(1); - AddAttr( - "y_num_col_dims", - R"DOC((int, default 1), The mul_op can take tensors with more than two, + .SetDefault(1) + .EqualGreaterThan(1); + AddAttr( + "y_num_col_dims", + R"DOC((int, default 1), The mul_op can take tensors with more than two, dimensions as its inputs. If the input $Y$ is a tensor with more than two dimensions, $Y$ will be flattened into a two-dimensional matrix first. The attribute `y_num_col_dims` determines how $Y$ is flattened. See comments of `x_num_col_dims` for more details. )DOC") - .SetDefault(1) - .EqualGreaterThan(1); - AddAttr( - "scale_x", - "scale_x to be used for int8 mul input data x. scale_x has the" - "same purpose as scale_in in OPs that support quantization." - "Only to be used with MKL-DNN INT8") - .SetDefault(1.0f); - AddAttr>( - "scale_y", - "scale_y to be used for int8 mul input data y. scale_y has the" - "same purpose as scale_weights in OPs that support quantization." - "Only to be used with MKL-DNN INT8") - .SetDefault({1.0f}); - AddAttr("scale_out", "scale_out to be used for int8 output data." - "Only used with MKL-DNN INT8") - .SetDefault(1.0f); - AddAttr( - "force_fp32_output", - "(bool, default false) Force quantize kernel output FP32, only " - "used in quantized MKL-DNN.") - .SetDefault(false); - AddComment(R"DOC( + .SetDefault(1) + .EqualGreaterThan(1); + AddAttr( + "scale_x", + "scale_x to be used for int8 mul input data x. 
scale_x has the" + "same purpose as scale_in in OPs that support quantization." + "Only to be used with MKL-DNN INT8") + .SetDefault(1.0f); + AddAttr>( + "scale_y", + "scale_y to be used for int8 mul input data y. scale_y has the" + "same purpose as scale_weights in OPs that support quantization." + "Only to be used with MKL-DNN INT8") + .SetDefault({1.0f}); + AddAttr("scale_out", + "scale_out to be used for int8 output data." + "Only used with MKL-DNN INT8") + .SetDefault(1.0f); + AddAttr( + "force_fp32_output", + "(bool, default false) Force quantize kernel output FP32, only " + "used in quantized MKL-DNN.") + .SetDefault(false); + AddComment(R"DOC( MPC mul Operator. )DOC"); - } + } }; class MpcMulOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { protected: - std::unordered_map - GetInputOutputWithSameType() const override { - return std::unordered_map{{"X", /*->*/ "Out"}}; - } + std::unordered_map& GetInputOutputWithSameType() + const override { + static std::unordered_map m{{"X", /*->*/ "Out"}}; + return m; + } }; class MpcMulGradOp : public framework::OperatorWithKernel { public: - using framework::OperatorWithKernel::OperatorWithKernel; - using Tensor = framework::Tensor; - - void InferShape(framework::InferShapeContext *ctx) const override { - auto out_grad_name = framework::GradVarName("Out"); - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null."); - PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null."); - PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true, - "Input(Out@GRAD) should not be null."); - auto x_grad_name = framework::GradVarName("X"); - auto y_grad_name = framework::GradVarName("Y"); - - auto x_dims = ctx->GetInputDim("X"); - auto y_dims = ctx->GetInputDim("Y"); - - if (ctx->HasOutput(x_grad_name)) { - ctx->SetOutputDim(x_grad_name, x_dims); - } - if (ctx->HasOutput(y_grad_name)) { - ctx->SetOutputDim(y_grad_name, y_dims); + using framework::OperatorWithKernel::OperatorWithKernel; + using Tensor = framework::Tensor; + + void InferShape(framework::InferShapeContext *ctx) const override { + auto out_grad_name = framework::GradVarName("Out"); + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true, + "Input(Out@GRAD) should not be null."); + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } } - } }; template -class MpcMulOpGradMaker : public framework::SingleGradOpDescMaker { +class MpcMulOpGradMaker : public framework::SingleGradOpMaker { public: - using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + using framework::SingleGradOpMaker::SingleGradOpMaker; protected: - std::unique_ptr Apply() const override { - std::unique_ptr retv(new T()); - retv->SetType("mpc_mul_grad"); - retv->SetInput("X", this->Input("X")); - retv->SetInput("Y", this->Input("Y")); - retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); - retv->SetAttrMap(this->Attrs()); - return retv; - } + void Apply(GradOpPtr grad) const 
override { + grad->SetType("mpc_mul_grad"); + grad->SetInput("X", this->Input("X")); + grad->SetInput("Y", this->Input("Y")); + grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + grad->SetAttrMap(this->Attrs()); + } }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; -REGISTER_OPERATOR(mpc_mul, ops::MpcMulOp, ops::MpcMulOpMaker, - ops::MpcMulOpInferVarType, - ops::MpcMulOpGradMaker); +REGISTER_OPERATOR(mpc_mul, ops::MpcMulOp, + ops::MpcMulOpMaker, + ops::MpcMulOpInferVarType, + ops::MpcMulOpGradMaker); REGISTER_OPERATOR(mpc_mul_grad, ops::MpcMulGradOp); REGISTER_OP_CPU_KERNEL( - mpc_mul, ops::MpcMulKernel); + mpc_mul, + ops::MpcMulKernel); REGISTER_OP_CPU_KERNEL( mpc_mul_grad, diff --git a/core/paddlefl_mpc/operators/mpc_mul_op.h b/core/paddlefl_mpc/operators/mpc_mul_op.h index 93498d3901e7d322c4b90d3a22774d8aafca1f93..67a8b065ada96eb316d10f4114c125005b812e50 100644 --- a/core/paddlefl_mpc/operators/mpc_mul_op.h +++ b/core/paddlefl_mpc/operators/mpc_mul_op.h @@ -1,20 +1,19 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
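For the MpcMulGradKernel below: with X flattened to an M x K matrix and Y to K x N (per x_num_col_dims / y_num_col_dims), the kernel's own comments use dX = dOut * Y' and dY = X' * dOut, and the transposes are produced by an Eigen shuffle that swaps the last two dimensions of the [2, rows, cols] share layout while leaving the share dimension in place. A small self-contained illustration of that shuffle, using plain Eigen tensors with made-up shapes rather than the framework's tensors:

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // A [2, rows, cols] share-style tensor; swap the last two dims, keep dim 0.
  Eigen::Tensor<float, 3> y(2, 3, 4);
  y.setRandom();
  Eigen::array<int, 3> permute = {0, 2, 1};               // same permutation the kernel uses
  Eigen::Tensor<float, 3> y_trans = y.shuffle(permute);   // shape becomes [2, 4, 3]
  return 0;
}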
*/ #pragma once #include "mpc_op.h" -#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h" namespace paddle { namespace operators { @@ -24,185 +23,170 @@ using Tensor = framework::Tensor; template class MpcMulKernel : public MpcOpKernel { public: - void ComputeImpl(const framework::ExecutionContext &ctx) const override { - auto *x = ctx.Input("X"); - auto *y = ctx.Input("Y"); - auto *out = ctx.Output("Out"); - - int x_num_col_dims = ctx.template Attr("x_num_col_dims"); - int y_num_col_dims = ctx.template Attr("y_num_col_dims"); - auto x_dims = x->dims(); - auto y_dims = y->dims(); - - int x_mat_width = 1; - int x_mat_height = 1; - int y_mat_width = 1; - int y_mat_height = 1; - - for (size_t i = 1; i < x_dims.size(); i++) { - if (i <= x_num_col_dims) { - x_mat_width *= x_dims[i]; - } else { - x_mat_height *= x_dims[i]; - } - } - for (size_t i = 1; i < y_dims.size(); i++) { - if (i <= y_num_col_dims) { - x_mat_width *= y_dims[i]; - } else { - y_mat_height *= y_dims[i]; - } - } - - Tensor x_matrix; - Tensor y_matrix; - x_matrix.ShareDataWith(*x); - y_matrix.ShareDataWith(*y); - - if (x_dims.size() > 3) { - x_matrix.Resize({2, x_mat_width, x_mat_height}); + void ComputeImpl(const framework::ExecutionContext &ctx) const override { + auto *x = ctx.Input("X"); + auto *y = ctx.Input("Y"); + auto *out = ctx.Output("Out"); + + int x_num_col_dims = ctx.template Attr("x_num_col_dims"); + int y_num_col_dims = ctx.template Attr("y_num_col_dims"); + auto x_dims = x->dims(); + auto y_dims = y->dims(); + + int x_mat_width = 1; + int x_mat_height = 1; + int y_mat_width = 1; + int y_mat_height = 1; + + for (size_t i = 1; i < x_dims.size(); i++) { + if (i <= x_num_col_dims) { + x_mat_width *= x_dims[i]; + } else { + x_mat_height *= x_dims[i]; + } + } + for (size_t i = 1; i < y_dims.size(); i++) { + if (i <= y_num_col_dims) { + y_mat_width *= y_dims[i]; + } else { + y_mat_height *= y_dims[i]; + } + } + + Tensor x_matrix; + Tensor y_matrix; + x_matrix.ShareDataWith(*x); + y_matrix.ShareDataWith(*y); + + x_matrix.Resize({2, x_mat_width, x_mat_height}); + y_matrix.Resize({2, y_mat_width, y_mat_height}); + + out->mutable_data(ctx.GetPlace()); + + auto out_dim = out->dims(); + if (out_dim.size() > 3) { + out->Resize({2, x_mat_width, y_mat_height}); + } + + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul( + &x_matrix, &y_matrix, out); + + if (out_dim.size() > 3) { + out->Resize(out_dim); + } + } - - if (y_dims.size() > 3) { - y_matrix.Resize({2, y_mat_width, y_mat_height}); - } - - out->mutable_data(ctx.GetPlace()); - - auto out_dim = out->dims(); - if (out_dim.size() > 3) { - out->Resize({2, x_mat_width, y_mat_height}); - } - - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul( - &x_matrix, &y_matrix, out); - - if (out_dim.size() > 3) { - out->Resize(out_dim); - } - } }; + template class MpcMulGradKernel : public MpcOpKernel { public: - void ComputeImpl(const framework::ExecutionContext &ctx) const override { - auto *x = ctx.Input("X"); - auto *y = ctx.Input("Y"); - auto *dout = ctx.Input(framework::GradVarName("Out")); - auto *dx = ctx.Output(framework::GradVarName("X")); - auto *dy = ctx.Output(framework::GradVarName("Y")); - int x_num_col_dims = ctx.template Attr("x_num_col_dims"); - int y_num_col_dims = ctx.template Attr("y_num_col_dims"); - auto x_dims = x->dims(); - auto y_dims = y->dims(); - auto dout_dims = dout->dims(); - - int x_mat_width = 1; - int x_mat_height = 1; - int y_mat_width = 1; - int y_mat_height = 1; - - for (size_t i = 1; i < 
x_dims.size(); i++) { - if (i <= x_num_col_dims) { - x_mat_width *= x_dims[i]; - } else { - x_mat_height *= x_dims[i]; - } - } - for (size_t i = 1; i < y_dims.size(); i++) { - if (i <= y_num_col_dims) { - y_mat_width *= y_dims[i]; - } else { - y_mat_height *= y_dims[i]; - } + void ComputeImpl(const framework::ExecutionContext &ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto *dout = ctx.Input(framework::GradVarName("Out")); + auto *dx = ctx.Output(framework::GradVarName("X")); + auto *dy = ctx.Output(framework::GradVarName("Y")); + int x_num_col_dims = ctx.template Attr("x_num_col_dims"); + int y_num_col_dims = ctx.template Attr("y_num_col_dims"); + auto x_dims = x->dims(); + auto y_dims = y->dims(); + auto dout_dims = dout->dims(); + + int x_mat_width = 1; + int x_mat_height = 1; + int y_mat_width = 1; + int y_mat_height = 1; + + for (size_t i = 1; i < x_dims.size(); i++) { + if (i <= x_num_col_dims) { + x_mat_width *= x_dims[i]; + } else { + x_mat_height *= x_dims[i]; + } + } + for (size_t i = 1; i < y_dims.size(); i++) { + if (i <= y_num_col_dims) { + y_mat_width *= y_dims[i]; + } else { + y_mat_height *= y_dims[i]; + } + } + + Tensor x_matrix; + Tensor y_matrix; + Tensor dout_matrix; + x_matrix.ShareDataWith(*x); + y_matrix.ShareDataWith(*y); + dout_matrix.ShareDataWith(*dout); + + x_matrix.Resize({2, x_mat_width, x_mat_height}); + y_matrix.Resize({2, y_mat_width, y_mat_height}); + dout_matrix.Resize({2, x_mat_width, y_mat_height}); + + if (dx != nullptr) { + dx->set_lod(x->lod()); + } + if (dy != nullptr) { + dy->set_lod(y->lod()); + } + + Tensor x_matrix_trans; + Tensor y_matrix_trans; + x_matrix_trans.mutable_data(x->dims(), ctx.GetPlace()); + y_matrix_trans.mutable_data(y->dims(), ctx.GetPlace()); + + x_matrix_trans.Resize({2, x_mat_height, x_mat_width}); + y_matrix_trans.Resize({2, y_mat_height, y_mat_width}); + + auto& dev_ctx = ctx.template device_context(); + const int Rank = 3; + + Eigen::array permute; + permute[0] = 0; + permute[1] = 2; + permute[2] = 1; + + if (dx) { + dx->mutable_data(ctx.GetPlace()); + if (dx->dims().size() > 3) { + dx->Resize({2, x_mat_width, x_mat_height}); + } + auto eigen_in = framework::EigenTensor::From(y_matrix); + auto eigen_out = framework::EigenTensor::From(y_matrix_trans); + auto* dev = dev_ctx.eigen_device(); + eigen_out.device(*dev) = eigen_in.shuffle(permute); + // dx = dout * y'. dx: M x K, dout : M x N, y : K x N + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul( + &dout_matrix, &y_matrix_trans, dx); + auto dx_dim = dx->dims(); + if (dx_dim.size() > 3) { + dx->Resize(dx_dim); + } + } + + if (dy) { + dy->mutable_data(ctx.GetPlace()); + if (dy->dims().size() > 3) { + dy->Resize({2, y_mat_width, y_mat_height}); + } + + auto eigen_in = framework::EigenTensor::From(x_matrix); + auto eigen_out = framework::EigenTensor::From(x_matrix_trans); + auto* dev = dev_ctx.eigen_device(); + eigen_out.device(*dev) = eigen_in.shuffle(permute); + // dy = x' * dout. 
dy K x N, dout : M x N, x : M x K + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul( + &x_matrix_trans, &dout_matrix, dy); + auto dy_dim = dy->dims(); + if (dy_dim.size() > 3) { + dy->Resize(dy_dim); + } + } } - - Tensor x_matrix; - Tensor y_matrix; - Tensor dout_matrix; - x_matrix.ShareDataWith(*x); - y_matrix.ShareDataWith(*y); - dout_matrix.ShareDataWith(*dout); - - if (x_dims.size() > 3) { - x_matrix.Resize({2, x_mat_width, x_mat_height}); - } - - if (y_dims.size() > 3) { - y_matrix.Resize({2, y_mat_width, y_mat_height}); - } - - if (dout_dims.size() > 3) { - dout_matrix.Resize({2, x_mat_width, y_mat_height}); - } - - if (dx != nullptr) { - dx->set_lod(x->lod()); - } - if (dy != nullptr) { - dy->set_lod(y->lod()); - } - - Tensor x_matrix_trans; - Tensor y_matrix_trans; - x_matrix_trans.mutable_data(x->dims(), ctx.GetPlace()); - y_matrix_trans.mutable_data(y->dims(), ctx.GetPlace()); - - if (x_dims.size() >= 3) { - x_matrix_trans.Resize({2, x_mat_height, x_mat_width}); - } - - if (y_dims.size() >= 3) { - y_matrix_trans.Resize({2, y_mat_height, y_mat_width}); - } - - auto &dev_ctx = ctx.template device_context(); - const int Rank = 3; - - Eigen::array permute; - permute[0] = 0; - permute[1] = 2; - permute[2] = 1; - - if (dx) { - dx->mutable_data(ctx.GetPlace()); - if (dx->dims().size() > 3) { - dx->Resize({2, x_mat_width, x_mat_height}); - } - auto eigen_in = framework::EigenTensor::From(y_matrix); - auto eigen_out = framework::EigenTensor::From(y_matrix_trans); - auto *dev = dev_ctx.eigen_device(); - eigen_out.device(*dev) = eigen_in.shuffle(permute); - // dx = dout * y'. dx: M x K, dout : M x N, y : K x N - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul( - &dout_matrix, &y_matrix_trans, dx); - auto dx_dim = dx->dims(); - if (dx_dim.size() > 3) { - dx->Resize(dx_dim); - } - } - - if (dy) { - dy->mutable_data(ctx.GetPlace()); - if (dy->dims().size() > 3) { - dy->Resize({2, y_mat_width, y_mat_height}); - } - - auto eigen_in = framework::EigenTensor::From(x_matrix); - auto eigen_out = framework::EigenTensor::From(x_matrix_trans); - auto *dev = dev_ctx.eigen_device(); - eigen_out.device(*dev) = eigen_in.shuffle(permute); - // dy = x' * dout. dy K x N, dout : M x N, x : M x K - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul( - &x_matrix_trans, &dout_matrix, dy); - auto dy_dim = dy->dims(); - if (dy_dim.size() > 3) { - dy->Resize(dy_dim); - } - } - } }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle + diff --git a/core/paddlefl_mpc/operators/mpc_op.h b/core/paddlefl_mpc/operators/mpc_op.h index 86b14b761dd7f98c4ed45184da9c04d4e455f74f..6cff543b0dbf03f13a9e8baeb83335986f5fe249 100644 --- a/core/paddlefl_mpc/operators/mpc_op.h +++ b/core/paddlefl_mpc/operators/mpc_op.h @@ -1,43 +1,45 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ // Description: #pragma once #include "paddle/fluid/framework/operator.h" -#include "core/paddlefl_mpc/mpc_protocol/context_holder.h" #include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h" +#include "core/paddlefl_mpc/mpc_protocol/context_holder.h" #include "core/privc3/circuit_context.h" namespace paddle { namespace operators { -template class MpcOpKernel : public framework::OpKernelBase { +template +class MpcOpKernel : public framework::OpKernelBase { public: - using ELEMENT_TYPE = T; - void Compute(const framework::ExecutionContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_instance()->mpc_protocol(), - "Mpc protocol is not yet initialized in executor"); - - std::shared_ptr mpc_ctx( - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context()); - mpc::ContextHolder::template run_with_context<>(&ctx, mpc_ctx, - [&] { ComputeImpl(ctx); }); - } - virtual void ComputeImpl(const framework::ExecutionContext &ctx) const = 0; + using ELEMENT_TYPE = T; + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_instance()->mpc_protocol(), + "Mpc protocol is not yet initialized in executor"); + + std::shared_ptr mpc_ctx(mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context()); + mpc::ContextHolder::template run_with_context<>(&ctx, mpc_ctx, + [&] { ComputeImpl(ctx); }); + } + virtual void ComputeImpl(const framework::ExecutionContext& ctx) const = 0; }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle + + diff --git a/core/paddlefl_mpc/operators/mpc_relu_op.cc b/core/paddlefl_mpc/operators/mpc_relu_op.cc index e5ffb37cb709846b6f624c62ac6335381d0be428..420abeb635088174465c8e764ca522f8ec6f91b2 100644 --- a/core/paddlefl_mpc/operators/mpc_relu_op.cc +++ b/core/paddlefl_mpc/operators/mpc_relu_op.cc @@ -18,20 +18,20 @@ namespace paddle { namespace operators { -// forward op defination +//forward op defination class MpcReluOp : public framework::OperatorWithKernel { -public: + public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { auto in_dims = ctx->GetInputDim("X"); ctx->SetOutputDim("Y", in_dims); } }; -// forward input & output defination +//forward input & output defination class MpcReluOpMaker : public framework::OpProtoAndCheckerMaker { -public: + public: void Make() override { AddInput("X", "The input tensor."); AddOutput("Y", "Output of relu_op"); @@ -41,43 +41,46 @@ Mpc Relu Operator. 
} }; -// backward op defination +//backward op defination class MpcReluGradOp : public framework::OperatorWithKernel { -public: + public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { auto in_dims = ctx->GetInputDim(framework::GradVarName("Y")); ctx->SetOutputDim(framework::GradVarName("X"), in_dims); } }; -// backward type, input & output defination +//backward type, input & output defination template -class MpcReluGradMaker : public framework::SingleGradOpDescMaker { +class MpcReluGradMaker : public framework::SingleGradOpMaker { public: - using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + using framework::SingleGradOpMaker::SingleGradOpMaker; - std::unique_ptr Apply() const override { - auto *op = new T(); - op->SetType("mpc_relu_grad"); - op->SetInput("Y", this->Output("Y")); - op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); - op->SetAttrMap(this->Attrs()); - op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - return std::unique_ptr(op); - } +protected: + void Apply(GradOpPtr grad) const override { + grad->SetType("mpc_relu_grad"); + grad->SetInput("Y", this->Output("Y")); + grad->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); + grad->SetAttrMap(this->Attrs()); + grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; using CPU = paddle::platform::CPUDeviceContext; -REGISTER_OPERATOR(mpc_relu, ops::MpcReluOp, ops::MpcReluOpMaker, +REGISTER_OPERATOR(mpc_relu, + ops::MpcReluOp, + ops::MpcReluOpMaker, ops::MpcReluGradMaker); REGISTER_OPERATOR(mpc_relu_grad, ops::MpcReluGradOp); -REGISTER_OP_CPU_KERNEL(mpc_relu, ops::MpcReluKernel); -REGISTER_OP_CPU_KERNEL(mpc_relu_grad, ops::MpcReluGradKernel); +REGISTER_OP_CPU_KERNEL(mpc_relu, + ops::MpcReluKernel); +REGISTER_OP_CPU_KERNEL(mpc_relu_grad, + ops::MpcReluGradKernel); diff --git a/core/paddlefl_mpc/operators/mpc_relu_op.h b/core/paddlefl_mpc/operators/mpc_relu_op.h index 73bcd92ce58e035069ef26c4cc19224e0f673bdb..f0a39264513bbfbfe4a9ba162d5bace2235b559b 100644 --- a/core/paddlefl_mpc/operators/mpc_relu_op.h +++ b/core/paddlefl_mpc/operators/mpc_relu_op.h @@ -14,43 +14,37 @@ #pragma once #include "mpc_op.h" -#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -// Define forward computation +//Define forward computation template class MpcReluKernel : public MpcOpKernel { public: - void ComputeImpl(const framework::ExecutionContext &ctx) const override { - const Tensor *in_t = ctx.Input("X"); - Tensor *out_t = ctx.Output("Y"); - auto x = in_t->data(); - auto y = out_t->mutable_data(ctx.GetPlace()); - PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, - "Protocol %s is not yet created in MPC Protocol."); - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->relu( - in_t, out_t); + void ComputeImpl(const framework::ExecutionContext& ctx) const override { + const Tensor* in_t = ctx.Input("X"); + Tensor* out_t = ctx.Output("Y"); + auto x = in_t->data(); + auto y = out_t->mutable_data(ctx.GetPlace()); + PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, "Protocol %s is not yet created in MPC Protocol."); + 
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->relu(in_t,out_t); } }; -// Define backward computation +//Define backward computation template class MpcReluGradKernel : public MpcOpKernel { public: - void ComputeImpl(const framework::ExecutionContext &ctx) const override { - auto *dy_t = ctx.Input(framework::GradVarName("Y")); - auto *y_t = ctx.Input("Y"); - auto *dx_t = ctx.Output(framework::GradVarName("X")); - auto dx = dx_t->mutable_data(ctx.GetPlace()); - mpc::MpcInstance::mpc_instance() - ->mpc_protocol() - ->mpc_operators() - ->relu_grad(y_t, dy_t, dx_t, 0.0); - } + void ComputeImpl(const framework::ExecutionContext& ctx) const override { + auto* dy_t = ctx.Input(framework::GradVarName("Y")); + auto* y_t = ctx.Input("Y"); + auto* dx_t = ctx.Output(framework::GradVarName("X")); + auto dx = dx_t->mutable_data(ctx.GetPlace()); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->relu_grad(y_t, dy_t, dx_t, 0.0); + } }; -} // namespace operaters -} // namespace paddle +}// namespace operaters +}// namespace paddle diff --git a/core/paddlefl_mpc/operators/mpc_sgd_op.cc b/core/paddlefl_mpc/operators/mpc_sgd_op.cc index f5329bb79ebd395aecda105026270a3ce34cfe93..c2f7cc9a5ef593d79ec5078b0eb48953d82e9773 100644 --- a/core/paddlefl_mpc/operators/mpc_sgd_op.cc +++ b/core/paddlefl_mpc/operators/mpc_sgd_op.cc @@ -1,16 +1,16 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #include "mpc_sgd_op.h" #include "paddle/fluid/framework/op_registry.h" @@ -20,77 +20,77 @@ namespace operators { class MpcSGDOp : public framework::OperatorWithKernel { public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Param"), - "Input(Param) of MPCSGDOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Grad"), - "Input(Grad) of MPCSGDOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("LearningRate"), - "Input(LearningRate) of MPCSGDOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), - "Output(ParamOut) of MPCSGDOp should not be null."); - - auto lr_dims = ctx->GetInputDim("LearningRate"); - PADDLE_ENFORCE_NE(framework::product(lr_dims), 0, - "Maybe the Input variable LearningRate has not " - "been initialized. 
You may need to confirm " - "if you put exe.run(startup_program) " - "after optimizer.minimize function."); - PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, - "Learning rate should have 1 element"); - auto param_dim = ctx->GetInputDim("Param"); - if (ctx->GetInputsVarType("Grad")[0] == - framework::proto::VarType::LOD_TENSOR) { - PADDLE_ENFORCE_EQ( - param_dim, ctx->GetInputDim("Grad"), - platform::errors::InvalidArgument( - "MPCSGD Operator's input Param and Grad dimensions do not match. " - "The Param %s shape is [%s], but the Grad %s shape is [%s].", - ctx->Inputs("Param")[0], param_dim, ctx->Inputs("Grad")[0], - ctx->GetInputDim("Grad"))); - } - ctx->SetOutputDim("ParamOut", param_dim); + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Param"), + "Input(Param) of MPCSGDOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(Grad) of MPCSGDOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LearningRate"), + "Input(LearningRate) of MPCSGDOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), + "Output(ParamOut) of MPCSGDOp should not be null."); + + auto lr_dims = ctx->GetInputDim("LearningRate"); + PADDLE_ENFORCE_NE(framework::product(lr_dims), 0, + "Maybe the Input variable LearningRate has not " + "been initialized. You may need to confirm " + "if you put exe.run(startup_program) " + "after optimizer.minimize function."); + PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, + "Learning rate should have 1 element"); + auto param_dim = ctx->GetInputDim("Param"); + if (ctx->GetInputsVarType("Grad")[0] == + framework::proto::VarType::LOD_TENSOR) { + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Grad"), + platform::errors::InvalidArgument( + "MPCSGD Operator's input Param and Grad dimensions do not match. 
" + "The Param %s shape is [%s], but the Grad %s shape is [%s].", + ctx->Inputs("Param")[0], param_dim, ctx->Inputs("Grad")[0], + ctx->GetInputDim("Grad"))); + } + ctx->SetOutputDim("ParamOut", param_dim); } protected: - framework::OpKernelType - GetExpectedKernelType(const framework::ExecutionContext &ctx) const override { - auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); - return framework::OpKernelType(data_type, ctx.device_context()); + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); + return framework::OpKernelType(data_type, ctx.device_context()); } }; class MpcSGDOpInferVarType : public framework::VarTypeInference { public: - void operator()(framework::InferVarTypeContext *ctx) const override { - auto &input_var_n = ctx->Input("Param")[0]; - auto in_var_type = ctx->GetType(input_var_n); - PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS || - in_var_type == framework::proto::VarType::LOD_TENSOR, - "The input Var's type should be LoDtensor or SelectedRows," - " but the received var(%s)'s type is %s", - input_var_n, in_var_type); - - for (auto &out_var_n : ctx->Output("ParamOut")) { - if (ctx->GetType(out_var_n) != in_var_type) { - ctx->SetType(out_var_n, in_var_type); - } - } + void operator()(framework::InferVarTypeContext *ctx) const override { + auto in_var_type = ctx->GetInputType("Param"); + PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS || + in_var_type == framework::proto::VarType::LOD_TENSOR, + "The input Var's type should be LoDtensor or SelectedRows," + " but the received var(%s)'s type is %s", + ctx->InputVarName("Param"), in_var_type); + ctx->SetOutputType("ParamOut", in_var_type); + + //for (auto &out_var_n : framework::StaticGraphVarTypeInference::Output(ctx, "ParamOut")) { + // if (ctx->GetVarType(out_var_n) != in_var_type) { + // ctx->SetType(out_var_n, in_var_type); + //} + //} } }; class MpcSGDOpMaker : public framework::OpProtoAndCheckerMaker { public: - void Make() override { - AddInput("Param", "(Tensor or SelectedRows) Input parameter"); - AddInput("LearningRate", "(Tensor) Learning rate of MPCSGD"); - AddInput("Grad", "(Tensor or SelectedRows) Input gradient"); - AddOutput("ParamOut", - "(Tensor or SelectedRows, same with Param) " - "Output parameter, should share the same memory with Param"); - AddComment(R"DOC( + void Make() override { + AddInput("Param", "(Tensor or SelectedRows) Input parameter"); + AddInput("LearningRate", "(Tensor) Learning rate of MPCSGD"); + AddInput("Grad", "(Tensor or SelectedRows) Input gradient"); + AddOutput("ParamOut", + "(Tensor or SelectedRows, same with Param) " + "Output parameter, should share the same memory with Param"); + AddComment(R"DOC( MPCSGD operator @@ -102,13 +102,13 @@ $$param\_out = param - learning\_rate * grad$$ } }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR( mpc_sgd, ops::MpcSGDOp, ops::MpcSGDOpMaker, - // paddle::framework::EmptyGradOpMaker, ops::MpcSGDOpInferVarType); REGISTER_OP_CPU_KERNEL( - mpc_sgd, ops::MpcSGDOpKernel); + mpc_sgd, + ops::MpcSGDOpKernel); diff --git a/core/paddlefl_mpc/operators/mpc_sgd_op.h b/core/paddlefl_mpc/operators/mpc_sgd_op.h index acdb4d17adf966ad5c530c60c87a7768303cbcac..805b74d04d38f4c8fdb040a958f77fea12e6f8ed 100644 --- a/core/paddlefl_mpc/operators/mpc_sgd_op.h +++ 
b/core/paddlefl_mpc/operators/mpc_sgd_op.h @@ -1,68 +1,62 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ #pragma once #include "mpc_op.h" #include "paddle/fluid/framework/eigen.h" -#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h" namespace paddle { namespace operators { template class MpcSGDOpKernel : public MpcOpKernel { -public: - void ComputeImpl(const framework::ExecutionContext &ctx) const override { - const auto *param_var = ctx.InputVar("Param"); - PADDLE_ENFORCE_EQ(param_var->IsType(), true, - "The Var(%s)'s type should be LoDTensor, " - "but the received is %s", - ctx.Inputs("Param").front(), - framework::ToTypeName(param_var->Type())); - - const auto *grad_var = ctx.InputVar("Grad"); - PADDLE_ENFORCE_EQ(grad_var->IsType(), true, - "The Var(%s)'s type should be LoDTensor, " - "but the received is %s", - ctx.Inputs("Grad").front(), - framework::ToTypeName(grad_var->Type())); - - const auto *learning_rate = ctx.Input("LearningRate"); - const auto *param = ctx.Input("Param"); - const auto *grad = ctx.Input("Grad"); - - auto *param_out = ctx.Output("ParamOut"); - - auto sz = param_out->numel(); - PADDLE_ENFORCE_EQ(param->numel(), sz); - PADDLE_ENFORCE_EQ(grad->numel(), sz); - - const double *lr = learning_rate->data(); - // const T *param_data = param->data(); - // const T *grad_data = grad->data(); - - T *out_data = param_out->mutable_data(ctx.GetPlace()); - PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, - "Protocol %s is not yet created in MPC Protocol."); - // update parameters - framework::Tensor temp; - temp.mutable_data(param->dims(), ctx.GetPlace()); - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale( - grad, lr[0], &temp); - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub( - param, &temp, param_out); - } + public: + void ComputeImpl(const framework::ExecutionContext &ctx) const override{ + const auto *param_var = ctx.InputVar("Param"); + PADDLE_ENFORCE_EQ(param_var->IsType(), true, + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.InputNames("Param").front(), + framework::ToTypeName(param_var->Type())); + + const auto *grad_var = ctx.InputVar("Grad"); + PADDLE_ENFORCE_EQ(grad_var->IsType(), true, + "The Var(%s)'s type should be LoDTensor, " + "but the received is %s", + ctx.InputNames("Grad").front(), + 
framework::ToTypeName(grad_var->Type())); + + const auto *learning_rate = ctx.Input("LearningRate"); + const auto *param = ctx.Input("Param"); + const auto *grad = ctx.Input("Grad"); + + auto *param_out = ctx.Output("ParamOut"); + + auto sz = param_out->numel(); + PADDLE_ENFORCE_EQ(param->numel(), sz); + PADDLE_ENFORCE_EQ(grad->numel(), sz); + + const double *lr = learning_rate->data(); + + param_out->mutable_data(ctx.GetPlace()); + PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, "Protocol %s is not yet created in MPC Protocol."); + // update parameters + framework::Tensor temp; + temp.mutable_data(param->dims(), ctx.GetPlace()); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(grad, lr[0], &temp); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(param, &temp, param_out); + } }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle diff --git a/core/paddlefl_mpc/operators/mpc_sigmoid_cross_entropy_with_logits_op.cc b/core/paddlefl_mpc/operators/mpc_sigmoid_cross_entropy_with_logits_op.cc index 09035196947d2cfa20ae691102ebf60426f421e9..0ebbc82601d86ce760c2d46225e3e9a6aea7c34c 100644 --- a/core/paddlefl_mpc/operators/mpc_sigmoid_cross_entropy_with_logits_op.cc +++ b/core/paddlefl_mpc/operators/mpc_sigmoid_cross_entropy_with_logits_op.cc @@ -117,21 +117,19 @@ MpcSigmoidCrossEntropyWithLogits Operator. }; template -class MpcSigmoidCrossEntropyWithLogitsGradOpMaker : public framework::SingleGradOpDescMaker { +class MpcSigmoidCrossEntropyWithLogitsGradOpMaker : public framework::SingleGradOpMaker { public: - using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + using framework::SingleGradOpMaker::SingleGradOpMaker; protected: - std::unique_ptr Apply() const override { - std::unique_ptr retv(new T()); - retv->SetType("mpc_sigmoid_cross_entropy_with_logits_grad"); - retv->SetInput("X", this->Input("X")); - retv->SetInput("Label", this->Input("Label")); - retv->SetInput("Out", this->Output("Out")); - retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - retv->SetAttrMap(this->Attrs()); - return retv; + void Apply(GradOpPtr grad) const override { + grad->SetType("mpc_sigmoid_cross_entropy_with_logits_grad"); + grad->SetInput("X", this->Input("X")); + grad->SetInput("Label", this->Input("Label")); + grad->SetInput("Out", this->Output("Out")); + grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + grad->SetAttrMap(this->Attrs()); } }; diff --git a/core/paddlefl_mpc/operators/mpc_square_op.cc b/core/paddlefl_mpc/operators/mpc_square_op.cc index 76d0e28e0191e8508edd95ac0ff97fca72a411a7..4142f01dac058267b4248363185e97a083f9788f 100644 --- a/core/paddlefl_mpc/operators/mpc_square_op.cc +++ b/core/paddlefl_mpc/operators/mpc_square_op.cc @@ -1,92 +1,93 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
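On the MpcSGDOpKernel above: the update param_out = param - learning_rate * grad (as the op's DOC string states) is realized with two MPC primitives, scale(grad, lr) into a temporary followed by sub(param, temp, param_out). A plain, non-MPC sketch of the same two-step update, with hypothetical helpers standing in for the mpc_operators() calls:

#include <cstddef>
#include <vector>

// Hypothetical stand-ins for mpc_operators()->scale(...) and ->sub(...).
void scale(const std::vector<double>& in, double factor, std::vector<double>* out) {
  for (std::size_t i = 0; i < in.size(); ++i) (*out)[i] = in[i] * factor;
}
void sub(const std::vector<double>& a, const std::vector<double>& b, std::vector<double>* out) {
  for (std::size_t i = 0; i < a.size(); ++i) (*out)[i] = a[i] - b[i];
}

// param_out = param - learning_rate * grad, in the same two steps as the kernel.
void sgd_step(const std::vector<double>& param, const std::vector<double>& grad,
              double learning_rate, std::vector<double>* param_out) {
  std::vector<double> temp(grad.size());
  scale(grad, learning_rate, &temp);   // temp = learning_rate * grad
  sub(param, temp, param_out);         // param_out = param - temp
}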
-// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ -#include "mpc_square_op.h" #include "paddle/fluid/framework/op_registry.h" +#include "mpc_square_op.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; + class MpcSquareOp : public framework::OperatorWithKernel { public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - platform::errors::NotFound( - "Input(X) of MpcSquareOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - platform::errors::NotFound( - "Output(Out) of MpcSquareOp should not be null.")); - ctx->ShareDim("X", /*->*/ "Out"); - ctx->ShareLoD("X", /*->*/ "Out"); - } + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("X"), true, + platform::errors::NotFound("Input(X) of MpcSquareOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("Out"), true, + platform::errors::NotFound("Output(Out) of MpcSquareOp should not be null.")); + ctx->ShareDim("X", /*->*/ "Out"); + ctx->ShareLoD("X", /*->*/ "Out"); + } }; class MpcSquareOpMaker : public framework::OpProtoAndCheckerMaker { public: - void Make() override { - AddInput("X", "(Tensor), The first input tensor of mpc square op."); - AddOutput("Out", "(Tensor), The output tensor of mpc square op."); - AddComment(R"DOC( + void Make() override { + AddInput("X", "(Tensor), The first input tensor of mpc square op."); + AddOutput("Out", "(Tensor), The output tensor of mpc square op."); + AddComment(R"DOC( MPC square Operator.. 
)DOC"); - } + } }; class MpcSquareGradOp : public framework::OperatorWithKernel { public: - using framework::OperatorWithKernel::OperatorWithKernel; - using Tensor = framework::Tensor; - - void InferShape(framework::InferShapeContext *ctx) const override { - ctx->ShareDim(framework::GradVarName("Out"), framework::GradVarName("X")); - ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X")); - } + using framework::OperatorWithKernel::OperatorWithKernel; + using Tensor = framework::Tensor; + + void InferShape(framework::InferShapeContext *ctx) const override { + ctx->ShareDim(framework::GradVarName("Out"), framework::GradVarName("X")); + ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X")); + } }; template -class MpcSquareGradOpMaker : public framework::SingleGradOpDescMaker { +class MpcSquareGradOpMaker : public framework::SingleGradOpMaker { public: - using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + using framework::SingleGradOpMaker::SingleGradOpMaker; protected: - std::unique_ptr Apply() const override { - std::unique_ptr retv(new T()); - retv->SetType("mpc_square_grad"); - retv->SetInput("X", this->Input("X")); - retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - return retv; - } + void Apply(GradOpPtr grad) const override { + grad->SetType("mpc_square_grad"); + grad->SetInput("X", this->Input("X")); + grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; -REGISTER_OPERATOR(mpc_square, ops::MpcSquareOp, ops::MpcSquareOpMaker, - ops::MpcSquareGradOpMaker); +REGISTER_OPERATOR(mpc_square, ops::MpcSquareOp, + ops::MpcSquareOpMaker, + ops::MpcSquareGradOpMaker); -REGISTER_OPERATOR(mpc_square_grad, ops::MpcSquareGradOp); +REGISTER_OPERATOR(mpc_square_grad, ops::MpcSquareGradOp); REGISTER_OP_CPU_KERNEL( - mpc_square, + mpc_square, ops::MpcSquareKernel); REGISTER_OP_CPU_KERNEL( - mpc_square_grad, + mpc_square_grad, ops::MpcSquareGradKernel); diff --git a/core/paddlefl_mpc/operators/mpc_square_op.h b/core/paddlefl_mpc/operators/mpc_square_op.h index c42e4f27b026a7ea6d20bf94a010501ece77a1b8..488b93eca4747bfdfa4e464fec6a23dd776fb6fa 100644 --- a/core/paddlefl_mpc/operators/mpc_square_op.h +++ b/core/paddlefl_mpc/operators/mpc_square_op.h @@ -1,19 +1,19 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #pragma once #include "mpc_op.h" -#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h" namespace paddle { namespace operators { @@ -23,33 +23,31 @@ using Tensor = framework::Tensor; template class MpcSquareKernel : public MpcOpKernel { public: - void ComputeImpl(const framework::ExecutionContext &ctx) const override { - auto *in_x_t = ctx.Input("X"); - auto *out_t = ctx.Output("Out"); - out_t->mutable_data(ctx.GetPlace()); - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul( - in_x_t, in_x_t, out_t); - } + void ComputeImpl(const framework::ExecutionContext &ctx) const override { + auto *in_x_t = ctx.Input("X"); + auto *out_t = ctx.Output("Out"); + out_t->mutable_data(ctx.GetPlace()); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul(in_x_t, in_x_t, out_t); + } }; template class MpcSquareGradKernel : public MpcOpKernel { public: - void ComputeImpl(const framework::ExecutionContext &ctx) const override { - auto *in_x_t = ctx.Input("X"); - auto *dout_t = ctx.Input(framework::GradVarName("Out")); - auto *dx_t = ctx.Output(framework::GradVarName("X")); - if (dx_t != nullptr) { - // allocate memory on device. - dx_t->mutable_data(ctx.GetPlace()); - // dx = dout * 2 * x - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale( - in_x_t, 2.0, dx_t); - mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul( - dx_t, dout_t, dx_t); + void ComputeImpl(const framework::ExecutionContext &ctx) const override { + auto *in_x_t = ctx.Input("X"); + auto *dout_t = ctx.Input(framework::GradVarName("Out")); + auto *dx_t = ctx.Output(framework::GradVarName("X")); + if (dx_t != nullptr) { + // allocate memory on device. + dx_t->mutable_data(ctx.GetPlace()); + // dx = dout * 2 * x + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(in_x_t, 2.0, dx_t); + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul(dx_t, dout_t, dx_t); + } } - } }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle + diff --git a/core/paddlefl_mpc/operators/mpc_sum_op.cc b/core/paddlefl_mpc/operators/mpc_sum_op.cc index 99d9a721b8dd3056ab14e2001594ca2a213cd7f0..df12d7638f57a9c99855c8d1bc16072f94a3b1eb 100644 --- a/core/paddlefl_mpc/operators/mpc_sum_op.cc +++ b/core/paddlefl_mpc/operators/mpc_sum_op.cc @@ -1,26 +1,25 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #include #include #include #include #include -#include "mpc_sum_op.h" -#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/var_type_inference.h" +#include "paddle/fluid/framework/op_registry.h" +#include "mpc_sum_op.h" namespace paddle { namespace operators { @@ -29,131 +28,135 @@ using Tensor = framework::Tensor; class MpcSumOp : public framework::OperatorWithKernel { public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInputs("X"), true, - platform::errors::NotFound( - "Input(X) of MpcElementwiseAddOp should not be null.")); - - PADDLE_ENFORCE_EQ( - ctx->HasOutput("Out"), true, - platform::errors::NotFound("Output(Out) of MulOp should not be null.")); - - auto x_var_types = ctx->GetInputsVarType("X"); - auto x_dims = ctx->GetInputsDim("X"); - auto N = x_dims.size(); - PADDLE_ENFORCE_GT( - N, 0, "ShapeError: The input tensor X's dimensions of SumOp " - "should be larger than 0. But received X's dimensions %d, " - "X's shape = [%s].", - N, &x_dims); - if (N == 1) { - VLOG(3) << "Warning: SumOp have only one input, may waste memory"; - } - - framework::DDim in_dim({0}); - for (size_t i = 0; i < x_dims.size(); ++i) { - auto &x_dim = x_dims[i]; - // x_dim.size() == 1 means the real dim of selected rows is [0] - if (x_var_types[i] == framework::proto::VarType::SELECTED_ROWS && - x_dim.size() == 1) { - continue; - } - if (framework::product(x_dim) == 0) { - continue; - } - if (framework::product(in_dim) == 0) { - in_dim = x_dim; - } else { - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ( - in_dim, x_dim, - "ShapeError: The input tensor X of SumOp must have same shape." - "But received X[0]'s shape = [%s], X[%d]'s shape = [%s].", - in_dim, i, x_dim); - } else { - PADDLE_ENFORCE_EQ( - in_dim.size(), x_dim.size(), - "ShapeError: The input tensor X of SumOp must have same " - "dimensions. But received X[0]'s dimensions = %d, X[0]'s shape = " - "[%s], X[%d]'s dimensions = %d, X[%d]'s shape = [%s].", - in_dim.size(), in_dim, i, x_dim.size(), i, x_dim); - // if in_dim or x_dim has -1, not check equal - for (int j = 0; j < x_dim.size(); ++j) { - if (x_dim[j] == -1 || in_dim[j] == -1) { - continue; + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInputs("X"), true, + platform::errors::NotFound("Input(X) of MpcElementwiseAddOp should not be null.")); + + PADDLE_ENFORCE_EQ( + ctx->HasOutput("Out"), true, + platform::errors::NotFound("Output(Out) of MulOp should not be null.")); + + auto x_var_types = ctx->GetInputsVarType("X"); + auto x_dims = ctx->GetInputsDim("X"); + auto N = x_dims.size(); + PADDLE_ENFORCE_GT( + N, 0, + "ShapeError: The input tensor X's dimensions of SumOp " + "should be larger than 0. 
But received X's dimensions %d, " + "X's shape = [%s].", + N, &x_dims); + if (N == 1) { + VLOG(3) << "Warning: SumOp have only one input, may waste memory"; + } + + framework::DDim in_dim({0}); + for (size_t i = 0; i < x_dims.size(); ++i) { + auto& x_dim = x_dims[i]; + // x_dim.size() == 1 means the real dim of selected rows is [0] + if (x_var_types[i] == framework::proto::VarType::SELECTED_ROWS && + x_dim.size() == 1) { + continue; } - PADDLE_ENFORCE_EQ( - in_dim[j], x_dim[j], - "ShapeError: The input tensor X of SumOp must have same shape " - "if not -1." - "But received X[0]'s shape = [%s], X[%d]'s shape = [%s].", - in_dim, i, x_dim); - } + if (framework::product(x_dim) == 0) { + continue; + } + if (framework::product(in_dim) == 0) { + in_dim = x_dim; + } else { + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ( + in_dim, x_dim, + "ShapeError: The input tensor X of SumOp must have same shape." + "But received X[0]'s shape = [%s], X[%d]'s shape = [%s].", + in_dim, i, x_dim); + } else { + PADDLE_ENFORCE_EQ( + in_dim.size(), x_dim.size(), + "ShapeError: The input tensor X of SumOp must have same " + "dimensions. But received X[0]'s dimensions = %d, X[0]'s shape = " + "[%s], X[%d]'s dimensions = %d, X[%d]'s shape = [%s].", + in_dim.size(), in_dim, i, x_dim.size(), i, x_dim); + // if in_dim or x_dim has -1, not check equal + for (int j = 0; j < x_dim.size(); ++j) { + if (x_dim[j] == -1 || in_dim[j] == -1) { + continue; + } + PADDLE_ENFORCE_EQ( + in_dim[j], x_dim[j], + "ShapeError: The input tensor X of SumOp must have same shape " + "if not -1." + "But received X[0]'s shape = [%s], X[%d]'s shape = [%s].", + in_dim, i, x_dim); + } + } + } } - } + + ctx->SetOutputDim("Out", in_dim); + ctx->ShareLoD("X", /*->*/ "Out"); } - ctx->SetOutputDim("Out", in_dim); - ctx->ShareLoD("X", /*->*/ "Out"); - } }; class MpcSumOpMaker : public framework::OpProtoAndCheckerMaker { public: - void Make() override { - AddInput("X", + void Make() override { + AddInput("X", "A Varaible list. The shape and data type of the list elements" "should be consistent. Variable can be multi-dimensional Tensor" "or LoDTensor, and data types can be: float32, float64, int32, " "int64.") - .AsDuplicable(); - AddOutput("Out", "the sum of input :code:`x`. its shape and data types are " - "consistent with :code:`x`."); - AddAttr("use_mkldnn", + .AsDuplicable(); + AddOutput("Out", + "the sum of input :code:`x`. its shape and data types are " + "consistent with :code:`x`."); + AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); - AddComment(R"DOC(This OP is used to sum one or more Tensor or LoDTensor + AddComment(R"DOC(This OP is used to sum one or more Tensor or LoDTensor of the input. 
If the input is LoDTensor, the output only shares LoD information with the first input.)DOC"); - } + } }; + class MpcSumGradMaker : public framework::GradOpDescMakerBase { public: - using framework::GradOpDescMakerBase::GradOpDescMakerBase; - - std::vector> operator()() const override { - auto x_grads = InputGrad("X", false); - std::vector> grad_ops; - grad_ops.reserve(x_grads.size()); - auto og = OutputGrad("Out"); - std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops), - [&og](const std::string &x_grad) { - auto *grad_op = new framework::OpDesc(); - grad_op->SetType("scale"); - grad_op->SetInput("X", og); - grad_op->SetOutput("Out", {x_grad}); - grad_op->SetAttr("scale", 1.0f); - return std::unique_ptr(grad_op); - }); - - return grad_ops; - } + using framework::GradOpDescMakerBase::GradOpDescMakerBase; + + std::vector> operator()() const override { + auto x_grads = InputGrad("X", false); + std::vector> grad_ops; + grad_ops.reserve(x_grads.size()); + auto og = OutputGrad("Out"); + std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops), + [&og](const std::string& x_grad) { + auto* grad_op = new framework::OpDesc(); + grad_op->SetType("scale"); + grad_op->SetInput("X", og); + grad_op->SetOutput("Out", {x_grad}); + grad_op->SetAttr("scale", 1.0f); + return std::unique_ptr(grad_op); + }); + + return grad_ops; + } }; DECLARE_INPLACE_OP_INFERER(MpcSumInplace, {"X", "Out"}); -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle namespace ops = paddle::operators; -// REGISTER_OP_WITHOUT_GRADIENT(mpc_sum, ops::MpcSumOp, ops::MpcSumOpMaker); -REGISTER_OPERATOR(mpc_sum, ops::MpcSumOp, ops::MpcSumOpMaker, - ops::MpcSumGradMaker, ops::MpcSumInplace); +//REGISTER_OP_WITHOUT_GRADIENT(mpc_sum, ops::MpcSumOp, ops::MpcSumOpMaker); +REGISTER_OPERATOR(mpc_sum, ops::MpcSumOp, + ops::MpcSumOpMaker, + ops::MpcSumGradMaker, + ops::MpcSumInplace); -REGISTER_OP_CPU_KERNEL( - mpc_sum, ops::MpcSumKernel); +REGISTER_OP_CPU_KERNEL(mpc_sum, ops::MpcSumKernel); diff --git a/core/paddlefl_mpc/operators/mpc_sum_op.h b/core/paddlefl_mpc/operators/mpc_sum_op.h index d00e33d29a609e0eaa1d3cf61793c0fe971df1ce..ae7ef56fe9e4015b1ee2c9e248e424051942138d 100644 --- a/core/paddlefl_mpc/operators/mpc_sum_op.h +++ b/core/paddlefl_mpc/operators/mpc_sum_op.h @@ -1,16 +1,16 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ #pragma once #include "mpc_op.h" @@ -23,62 +23,57 @@ using Tensor = framework::Tensor; template class MpcSumKernel : public MpcOpKernel { public: - void ComputeImpl(const framework::ExecutionContext &ctx) const override { - auto in_vars = ctx.MultiInputVar("X"); - size_t in_num = in_vars.size(); - auto out_var = ctx.OutputVar("Out"); - bool in_place = out_var == in_vars[0]; + void ComputeImpl(const framework::ExecutionContext &ctx) const override { + auto in_vars = ctx.MultiInputVar("X"); + size_t in_num = in_vars.size(); + auto out_var = ctx.OutputVar("Out"); + bool in_place = out_var == in_vars[0]; - if (out_var->IsType()) { - auto *out = out_var->GetMutable(); - auto *out_ptr = out->mutable_data(ctx.GetPlace()); - if (in_num >= 1 && in_vars[0]->IsType()) { - auto &in_0_tensor = in_vars[0]->Get(); - if (in_0_tensor.numel() > 0) { - in_place = (in_0_tensor.data() == out_ptr); - } - } - int start = in_place ? 1 : 0; - if (!in_place) { - if ((in_num >= 2) && in_vars[0]->IsType() && - in_vars[1]->IsType()) { - auto &in_0 = in_vars[0]->Get(); - auto &in_1 = in_vars[1]->Get(); - if (in_0.numel() && in_1.numel()) { - mpc::MpcInstance::mpc_instance() - ->mpc_protocol() - ->mpc_operators() - ->add(&in_0, &in_1, out); - start = 2; - } - } - if (start != 2) { - auto t = framework::EigenVector::Flatten(*out); - auto &device_ctx = ctx.template device_context(); - t.device(*device_ctx.eigen_device()) = t.constant(static_cast(0)); - } - } + if (out_var->IsType()) { + auto *out = out_var->GetMutable(); + auto *out_ptr = out->mutable_data(ctx.GetPlace()); + if (in_num >= 1 && in_vars[0]->IsType()) { + auto &in_0_tensor = in_vars[0]->Get(); + if (in_0_tensor.numel() > 0) { + in_place = (in_0_tensor.data() == out_ptr); + } + } + int start = in_place ? 
1 : 0; + if (!in_place) { + if ((in_num >= 2) && in_vars[0]->IsType() && + in_vars[1]->IsType()) { + auto &in_0 = in_vars[0]->Get(); + auto &in_1 = in_vars[1]->Get(); + if (in_0.numel() && in_1.numel()) { + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(&in_0, &in_1, out); + start = 2; + } + } + if (start != 2) { + auto t = framework::EigenVector::Flatten(*out); + auto &device_ctx = ctx.template device_context(); + t.device(*device_ctx.eigen_device()) = t.constant(static_cast(0)); + } + } - // If in_place, just skip the first tensor - for (size_t i = start; i < in_num; i++) { - if (in_vars[i]->IsType()) { - auto &in_t = in_vars[i]->Get(); - if (in_t.numel() == 0) { - continue; - } - mpc::MpcInstance::mpc_instance() - ->mpc_protocol() - ->mpc_operators() - ->add(out, &in_t, out); - } else { - PADDLE_THROW("Variable type must be LoDTensor/SelectedRows."); + // If in_place, just skip the first tensor + for (size_t i = start; i < in_num; i++) { + if (in_vars[i]->IsType()) { + auto &in_t = in_vars[i]->Get(); + if (in_t.numel() == 0) { + continue; + } + mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(out, &in_t, out); + } else { + PADDLE_THROW("Variable type must be LoDTensor/SelectedRows."); + } + } + }else { + PADDLE_THROW("Unexpected branch, output variable type is %s", + framework::ToTypeName(out_var->Type())); } - } - } else { - PADDLE_THROW("Unexpected branch, output variable type is %s", - framework::ToTypeName(out_var->Type())); } - } }; -} // namespace operators -} // namespace paddle +} // namespace operators +} // namespace paddle + diff --git a/core/privc3/boolean_tensor_test.cc b/core/privc3/boolean_tensor_test.cc index 79d832b85bb9ba05d84b14416cea00937fba7da8..c44d6cd0e2d92a44df4e407427aba9f63b09a304 100644 --- a/core/privc3/boolean_tensor_test.cc +++ b/core/privc3/boolean_tensor_test.cc @@ -16,19 +16,19 @@ #include #include -#include "paddle/fluid/framework/op_info.h" -#include "paddle/fluid/framework/op_registry.h" +#include "gtest/gtest.h" +#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/platform/init.h" -#include "gtest/gtest.h" #include "boolean_tensor.h" -#include "circuit_context.h" #include "fixedpoint_tensor.h" -#include "core/paddlefl_mpc/mpc_protocol/mesh_network.h" #include "paddle_tensor.h" +#include "circuit_context.h" +#include "core/paddlefl_mpc/mpc_protocol/mesh_network.h" namespace aby3 { @@ -36,1158 +36,1298 @@ using paddle::framework::Tensor; class BooleanTensorTest : public ::testing::Test { public: - paddle::platform::CPUDeviceContext _cpu_ctx; + paddle::platform::CPUDeviceContext _cpu_ctx; - std::shared_ptr _exec_ctx; - std::shared_ptr _mpc_ctx[3]; + std::shared_ptr _exec_ctx; + std::shared_ptr _mpc_ctx[3]; - std::shared_ptr _store; + std::shared_ptr _store; - std::thread _t[3]; + std::thread _t[3]; - std::shared_ptr _tensor_factory; + std::shared_ptr _tensor_factory; - void SetUp() { - paddle::framework::OperatorBase *op = nullptr; - paddle::framework::Scope scope; - paddle::framework::RuntimeContext ctx({}, {}); - // only device_ctx is needed - _exec_ctx = std::make_shared( - *op, scope, _cpu_ctx, ctx, nullptr); + virtual ~BooleanTensorTest() noexcept {} - _store = std::make_shared(); + void SetUp() { + paddle::framework::OperatorBase* op = nullptr; + 
paddle::framework::Scope scope; + paddle::framework::RuntimeContext ctx({}, {}); + // only device_ctx is needed + _exec_ctx = std::make_shared( + *op, scope, _cpu_ctx, ctx); - std::thread t[3]; - for (size_t i = 0; i < 3; ++i) { - _t[i] = std::thread(&BooleanTensorTest::gen_mpc_ctx, this, i); - } + _store = std::make_shared(); - for (auto &ti : _t) { - ti.join(); - } + std::thread t[3]; + for (size_t i = 0; i < 3; ++i) { + _t[i] = std::thread(&BooleanTensorTest::gen_mpc_ctx, this, i); + } - _tensor_factory = std::make_shared(&_cpu_ctx); - } + for (auto& ti : _t) { + ti.join(); + } + + _tensor_factory = std::make_shared(&_cpu_ctx); + } - std::shared_ptr gen_network(size_t idx) { + std::shared_ptr gen_network(size_t idx) { - return std::make_shared(idx, "127.0.0.1", 3, - "test_prefix", _store); - } + return std::make_shared(idx, + "127.0.0.1", + 3, + "test_prefix", + _store); + } - void gen_mpc_ctx(size_t idx) { - auto net = gen_network(idx); - net->init(); - _mpc_ctx[idx] = std::make_shared(idx, net); - } + void gen_mpc_ctx(size_t idx) { + auto net = gen_network(idx); + net->init(); + _mpc_ctx[idx] = std::make_shared(idx, net); + } - std::shared_ptr> gen1() { - return _tensor_factory->template create(std::vector{1}); - } + std::shared_ptr> gen1() { + return _tensor_factory->template create(std::vector{1}); + } - std::shared_ptr> - gen(const std::vector &shape) { - return _tensor_factory->template create(shape); - } + std::shared_ptr> gen(const std::vector& shape) { + return _tensor_factory->template create(shape); + } }; using paddle::mpc::ContextHolder; TEST_F(BooleanTensorTest, empty_test) { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - []() { ; }); + ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], [](){ ; }); } using BTensor = BooleanTensor; TEST_F(BooleanTensorTest, reveal1_test) { - std::shared_ptr> s[3] = {gen1(), gen1(), gen1()}; - auto p = gen1(); - s[0]->data()[0] = 2; - s[1]->data()[0] = 3; - s[2]->data()[0] = 4; - - BTensor b0(s[0].get(), s[1].get()); - BTensor b1(s[1].get(), s[2].get()); - BTensor b2(s[2].get(), s[0].get()); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context( - _exec_ctx.get(), _mpc_ctx[0], [&]() { b0.reveal_to_one(0, p.get()); }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context( - _exec_ctx.get(), _mpc_ctx[1], [&]() { b1.reveal_to_one(0, nullptr); }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context( - _exec_ctx.get(), _mpc_ctx[2], [&]() { b2.reveal_to_one(0, nullptr); }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(2 ^ 3 ^ 4, p->data()[0]); + std::shared_ptr> s[3] = { gen1(), gen1(), gen1() }; + auto p = gen1(); + s[0]->data()[0] = 2; + s[1]->data()[0] = 3; + s[2]->data()[0] = 4; + + BTensor b0(s[0].get(), s[1].get()); + BTensor b1(s[1].get(), s[2].get()); + BTensor b2(s[2].get(), s[0].get()); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + b0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + b1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + b2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(2 ^ 3 ^ 4, p->data()[0]); } TEST_F(BooleanTensorTest, reveal2_test) { - 
std::shared_ptr> s[3] = {gen1(), gen1(), gen1()}; - std::shared_ptr> p[3] = {gen1(), gen1(), gen1()}; - s[0]->data()[0] = 2; - s[1]->data()[0] = 3; - s[2]->data()[0] = 4; - - BTensor b0(s[0].get(), s[1].get()); - BTensor b1(s[1].get(), s[2].get()); - BTensor b2(s[2].get(), s[0].get()); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { b0.reveal(p[0].get()); }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[1], - [&]() { b1.reveal(p[1].get()); }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { b2.reveal(p[2].get()); }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(2 ^ 3 ^ 4, p[0]->data()[0]); - EXPECT_EQ(2 ^ 3 ^ 4, p[1]->data()[0]); - EXPECT_EQ(2 ^ 3 ^ 4, p[2]->data()[0]); + std::shared_ptr> s[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> p[3] = { gen1(), gen1(), gen1() }; + s[0]->data()[0] = 2; + s[1]->data()[0] = 3; + s[2]->data()[0] = 4; + + BTensor b0(s[0].get(), s[1].get()); + BTensor b1(s[1].get(), s[2].get()); + BTensor b2(s[2].get(), s[0].get()); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + b0.reveal(p[0].get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + b1.reveal(p[1].get()); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + b2.reveal(p[2].get()); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(2 ^ 3 ^ 4, p[0]->data()[0]); + EXPECT_EQ(2 ^ 3 ^ 4, p[1]->data()[0]); + EXPECT_EQ(2 ^ 3 ^ 4, p[2]->data()[0]); } TEST_F(BooleanTensorTest, xor1_test) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - std::shared_ptr> sr[3] = {gen1(), gen1(), gen1()}; - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = 5 = 2 ^ 3 ^ 4 - sl[0]->data()[0] = 2; - sl[1]->data()[0] = 3; - sl[2]->data()[0] = 4; - - // rhs = 7 = 1 ^ 2 ^ 4 - sr[0]->data()[0] = 1; - sr[1]->data()[0] = 2; - sr[2]->data()[0] = 4; - - auto p = gen1(); - - BTensor bl0(sl[0].get(), sl[1].get()); - BTensor bl1(sl[1].get(), sl[2].get()); - BTensor bl2(sl[2].get(), sl[0].get()); - - BTensor br0(sr[0].get(), sr[1].get()); - BTensor br1(sr[1].get(), sr[2].get()); - BTensor br2(sr[2].get(), sr[0].get()); - - BTensor bout0(sout[0].get(), sout[1].get()); - BTensor bout1(sout[2].get(), sout[3].get()); - BTensor bout2(sout[4].get(), sout[5].get()); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { - bl0.bitwise_xor(&br0, &bout0); - bout0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[1], - [&]() { - bl1.bitwise_xor(&br1, &bout1); - bout1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { - bl2.bitwise_xor(&br2, &bout2); - bout2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(5 ^ 7, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sr[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1() }; + + // lhs = 5 = 2 ^ 3 ^ 4 + 
sl[0]->data()[0] = 2; + sl[1]->data()[0] = 3; + sl[2]->data()[0] = 4; + + // rhs = 7 = 1 ^ 2 ^ 4 + sr[0]->data()[0] = 1; + sr[1]->data()[0] = 2; + sr[2]->data()[0] = 4; + + auto p = gen1(); + + BTensor bl0(sl[0].get(), sl[1].get()); + BTensor bl1(sl[1].get(), sl[2].get()); + BTensor bl2(sl[2].get(), sl[0].get()); + + BTensor br0(sr[0].get(), sr[1].get()); + BTensor br1(sr[1].get(), sr[2].get()); + BTensor br2(sr[2].get(), sr[0].get()); + + BTensor bout0(sout[0].get(), sout[1].get()); + BTensor bout1(sout[2].get(), sout[3].get()); + BTensor bout2(sout[4].get(), sout[5].get()); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + bl0.bitwise_xor(&br0, &bout0); + bout0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + bl1.bitwise_xor(&br1, &bout1); + bout1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + bl2.bitwise_xor(&br2, &bout2); + bout2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(5 ^ 7, p->data()[0]); } TEST_F(BooleanTensorTest, xor2_test) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - - auto pr = gen1(); - - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = 5 = 2 ^ 3 ^ 4 - sl[0]->data()[0] = 2; - sl[1]->data()[0] = 3; - sl[2]->data()[0] = 4; - - // rhs = 7 - pr->data()[0] = 7; - - auto p = gen1(); - - BTensor bl0(sl[0].get(), sl[1].get()); - BTensor bl1(sl[1].get(), sl[2].get()); - BTensor bl2(sl[2].get(), sl[0].get()); - - BTensor bout0(sout[0].get(), sout[1].get()); - BTensor bout1(sout[2].get(), sout[3].get()); - BTensor bout2(sout[4].get(), sout[5].get()); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context( - _exec_ctx.get(), _mpc_ctx[0], [&]() { - bl0.bitwise_xor(pr.get(), &bout0); - bout0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context( - _exec_ctx.get(), _mpc_ctx[1], [&]() { - bl1.bitwise_xor(pr.get(), &bout1); - bout1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context( - _exec_ctx.get(), _mpc_ctx[2], [&]() { - bl2.bitwise_xor(pr.get(), &bout2); - bout2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(5 ^ 7, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + + auto pr = gen1(); + + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1() }; + + // lhs = 5 = 2 ^ 3 ^ 4 + sl[0]->data()[0] = 2; + sl[1]->data()[0] = 3; + sl[2]->data()[0] = 4; + + // rhs = 7 + pr->data()[0] = 7; + + auto p = gen1(); + + BTensor bl0(sl[0].get(), sl[1].get()); + BTensor bl1(sl[1].get(), sl[2].get()); + BTensor bl2(sl[2].get(), sl[0].get()); + + BTensor bout0(sout[0].get(), sout[1].get()); + BTensor bout1(sout[2].get(), sout[3].get()); + BTensor bout2(sout[4].get(), sout[5].get()); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + bl0.bitwise_xor(pr.get(), &bout0); + bout0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + bl1.bitwise_xor(pr.get(), &bout1); + bout1.reveal_to_one(0, nullptr); + }); + } + ); 
+ + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + bl2.bitwise_xor(pr.get(), &bout2); + bout2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(5 ^ 7, p->data()[0]); } TEST_F(BooleanTensorTest, and1_test) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - std::shared_ptr> sr[3] = {gen1(), gen1(), gen1()}; - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = 5 = 2 ^ 3 ^ 4 - sl[0]->data()[0] = 2; - sl[1]->data()[0] = 3; - sl[2]->data()[0] = 4; - - // rhs = 7 = 1 ^ 2 ^ 4 - sr[0]->data()[0] = 1; - sr[1]->data()[0] = 2; - sr[2]->data()[0] = 4; - - auto p = gen1(); - - BTensor bl0(sl[0].get(), sl[1].get()); - BTensor bl1(sl[1].get(), sl[2].get()); - BTensor bl2(sl[2].get(), sl[0].get()); - - BTensor br0(sr[0].get(), sr[1].get()); - BTensor br1(sr[1].get(), sr[2].get()); - BTensor br2(sr[2].get(), sr[0].get()); - - BTensor bout0(sout[0].get(), sout[1].get()); - BTensor bout1(sout[2].get(), sout[3].get()); - BTensor bout2(sout[4].get(), sout[5].get()); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { - bl0.bitwise_and(&br0, &bout0); - bout0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[1], - [&]() { - bl1.bitwise_and(&br1, &bout1); - bout1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { - bl2.bitwise_and(&br2, &bout2); - bout2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(5 & 7, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sr[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1() }; + + // lhs = 5 = 2 ^ 3 ^ 4 + sl[0]->data()[0] = 2; + sl[1]->data()[0] = 3; + sl[2]->data()[0] = 4; + + // rhs = 7 = 1 ^ 2 ^ 4 + sr[0]->data()[0] = 1; + sr[1]->data()[0] = 2; + sr[2]->data()[0] = 4; + + auto p = gen1(); + + BTensor bl0(sl[0].get(), sl[1].get()); + BTensor bl1(sl[1].get(), sl[2].get()); + BTensor bl2(sl[2].get(), sl[0].get()); + + BTensor br0(sr[0].get(), sr[1].get()); + BTensor br1(sr[1].get(), sr[2].get()); + BTensor br2(sr[2].get(), sr[0].get()); + + BTensor bout0(sout[0].get(), sout[1].get()); + BTensor bout1(sout[2].get(), sout[3].get()); + BTensor bout2(sout[4].get(), sout[5].get()); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + bl0.bitwise_and(&br0, &bout0); + bout0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + bl1.bitwise_and(&br1, &bout1); + bout1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + bl2.bitwise_and(&br2, &bout2); + bout2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(5 & 7, p->data()[0]); } TEST_F(BooleanTensorTest, and2_test) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - - auto pr = gen1(); - - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = 5 = 2 ^ 3 ^ 4 - sl[0]->data()[0] = 2; - sl[1]->data()[0] = 3; - sl[2]->data()[0] = 
4; - - // rhs = 7 - pr->data()[0] = 7; - - auto p = gen1(); - - BTensor bl0(sl[0].get(), sl[1].get()); - BTensor bl1(sl[1].get(), sl[2].get()); - BTensor bl2(sl[2].get(), sl[0].get()); - - BTensor bout0(sout[0].get(), sout[1].get()); - BTensor bout1(sout[2].get(), sout[3].get()); - BTensor bout2(sout[4].get(), sout[5].get()); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context( - _exec_ctx.get(), _mpc_ctx[0], [&]() { - bl0.bitwise_and(pr.get(), &bout0); - bout0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context( - _exec_ctx.get(), _mpc_ctx[1], [&]() { - bl1.bitwise_and(pr.get(), &bout1); - bout1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context( - _exec_ctx.get(), _mpc_ctx[2], [&]() { - bl2.bitwise_and(pr.get(), &bout2); - bout2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(5 & 7, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + + auto pr = gen1(); + + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1() }; + + // lhs = 5 = 2 ^ 3 ^ 4 + sl[0]->data()[0] = 2; + sl[1]->data()[0] = 3; + sl[2]->data()[0] = 4; + + // rhs = 7 + pr->data()[0] = 7; + + auto p = gen1(); + + BTensor bl0(sl[0].get(), sl[1].get()); + BTensor bl1(sl[1].get(), sl[2].get()); + BTensor bl2(sl[2].get(), sl[0].get()); + + BTensor bout0(sout[0].get(), sout[1].get()); + BTensor bout1(sout[2].get(), sout[3].get()); + BTensor bout2(sout[4].get(), sout[5].get()); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + bl0.bitwise_and(pr.get(), &bout0); + bout0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + bl1.bitwise_and(pr.get(), &bout1); + bout1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + bl2.bitwise_and(pr.get(), &bout2); + bout2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(5 & 7, p->data()[0]); } TEST_F(BooleanTensorTest, or1_test) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - std::shared_ptr> sr[3] = {gen1(), gen1(), gen1()}; - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = 5 = 2 ^ 3 ^ 4 - sl[0]->data()[0] = 2; - sl[1]->data()[0] = 3; - sl[2]->data()[0] = 4; - - // rhs = 7 = 1 ^ 2 ^ 4 - sr[0]->data()[0] = 1; - sr[1]->data()[0] = 2; - sr[2]->data()[0] = 4; - - auto p = gen1(); - - BTensor bl0(sl[0].get(), sl[1].get()); - BTensor bl1(sl[1].get(), sl[2].get()); - BTensor bl2(sl[2].get(), sl[0].get()); - - BTensor br0(sr[0].get(), sr[1].get()); - BTensor br1(sr[1].get(), sr[2].get()); - BTensor br2(sr[2].get(), sr[0].get()); - - BTensor bout0(sout[0].get(), sout[1].get()); - BTensor bout1(sout[2].get(), sout[3].get()); - BTensor bout2(sout[4].get(), sout[5].get()); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { - bl0.bitwise_or(&br0, &bout0); - bout0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[1], - [&]() { - bl1.bitwise_or(&br1, &bout1); - bout1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - 
ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { - bl2.bitwise_or(&br2, &bout2); - bout2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(5 | 7, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sr[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1() }; + + // lhs = 5 = 2 ^ 3 ^ 4 + sl[0]->data()[0] = 2; + sl[1]->data()[0] = 3; + sl[2]->data()[0] = 4; + + // rhs = 7 = 1 ^ 2 ^ 4 + sr[0]->data()[0] = 1; + sr[1]->data()[0] = 2; + sr[2]->data()[0] = 4; + + auto p = gen1(); + + BTensor bl0(sl[0].get(), sl[1].get()); + BTensor bl1(sl[1].get(), sl[2].get()); + BTensor bl2(sl[2].get(), sl[0].get()); + + BTensor br0(sr[0].get(), sr[1].get()); + BTensor br1(sr[1].get(), sr[2].get()); + BTensor br2(sr[2].get(), sr[0].get()); + + BTensor bout0(sout[0].get(), sout[1].get()); + BTensor bout1(sout[2].get(), sout[3].get()); + BTensor bout2(sout[4].get(), sout[5].get()); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + bl0.bitwise_or(&br0, &bout0); + bout0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + bl1.bitwise_or(&br1, &bout1); + bout1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + bl2.bitwise_or(&br2, &bout2); + bout2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(5 | 7, p->data()[0]); } TEST_F(BooleanTensorTest, or2_test) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - - auto pr = gen1(); - - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = 5 = 2 ^ 3 ^ 4 - sl[0]->data()[0] = 2; - sl[1]->data()[0] = 3; - sl[2]->data()[0] = 4; - - // rhs = 7 - pr->data()[0] = 7; - - auto p = gen1(); - - BTensor bl0(sl[0].get(), sl[1].get()); - BTensor bl1(sl[1].get(), sl[2].get()); - BTensor bl2(sl[2].get(), sl[0].get()); - - BTensor bout0(sout[0].get(), sout[1].get()); - BTensor bout1(sout[2].get(), sout[3].get()); - BTensor bout2(sout[4].get(), sout[5].get()); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { - bl0.bitwise_or(pr.get(), &bout0); - bout0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[1], - [&]() { - bl1.bitwise_or(pr.get(), &bout1); - bout1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { - bl2.bitwise_or(pr.get(), &bout2); - bout2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(5 | 7, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + + auto pr = gen1(); + + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1() }; + + // lhs = 5 = 2 ^ 3 ^ 4 + sl[0]->data()[0] = 2; + sl[1]->data()[0] = 3; + sl[2]->data()[0] = 4; + + // rhs = 7 + pr->data()[0] = 7; + + auto p = gen1(); + + BTensor bl0(sl[0].get(), sl[1].get()); + BTensor bl1(sl[1].get(), sl[2].get()); + BTensor bl2(sl[2].get(), sl[0].get()); + + BTensor bout0(sout[0].get(), sout[1].get()); + BTensor bout1(sout[2].get(), sout[3].get()); + BTensor 
bout2(sout[4].get(), sout[5].get()); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + bl0.bitwise_or(pr.get(), &bout0); + bout0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + bl1.bitwise_or(pr.get(), &bout1); + bout1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + bl2.bitwise_or(pr.get(), &bout2); + bout2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(5 | 7, p->data()[0]); } TEST_F(BooleanTensorTest, not_test) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = 5 = 2 ^ 3 ^ 4 - sl[0]->data()[0] = 2; - sl[1]->data()[0] = 3; - sl[2]->data()[0] = 4; - - auto p = gen1(); - - BTensor bl0(sl[0].get(), sl[1].get()); - BTensor bl1(sl[1].get(), sl[2].get()); - BTensor bl2(sl[2].get(), sl[0].get()); - - BTensor bout0(sout[0].get(), sout[1].get()); - BTensor bout1(sout[2].get(), sout[3].get()); - BTensor bout2(sout[4].get(), sout[5].get()); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { - bl0.bitwise_not(&bout0); - bout0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[1], - [&]() { - bl1.bitwise_not(&bout1); - bout1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { - bl2.bitwise_not(&bout2); - bout2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(~5, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1() }; + + // lhs = 5 = 2 ^ 3 ^ 4 + sl[0]->data()[0] = 2; + sl[1]->data()[0] = 3; + sl[2]->data()[0] = 4; + + auto p = gen1(); + + BTensor bl0(sl[0].get(), sl[1].get()); + BTensor bl1(sl[1].get(), sl[2].get()); + BTensor bl2(sl[2].get(), sl[0].get()); + + BTensor bout0(sout[0].get(), sout[1].get()); + BTensor bout1(sout[2].get(), sout[3].get()); + BTensor bout2(sout[4].get(), sout[5].get()); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + bl0.bitwise_not(&bout0); + bout0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + bl1.bitwise_not(&bout1); + bout1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + bl2.bitwise_not(&bout2); + bout2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(~5, p->data()[0]); } TEST_F(BooleanTensorTest, lshift_test) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = 5 = 2 ^ 3 ^ 4 - sl[0]->data()[0] = 2; - sl[1]->data()[0] = 3; - sl[2]->data()[0] = 4; - - auto p = gen1(); - - BTensor bl0(sl[0].get(), sl[1].get()); - BTensor bl1(sl[1].get(), sl[2].get()); - BTensor bl2(sl[2].get(), sl[0].get()); - - BTensor 
bout0(sout[0].get(), sout[1].get()); - BTensor bout1(sout[2].get(), sout[3].get()); - BTensor bout2(sout[4].get(), sout[5].get()); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { - bl0.lshift(1, &bout0); - bout0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[1], - [&]() { - bl1.lshift(1, &bout1); - bout1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { - bl2.lshift(1, &bout2); - bout2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(5 << 1, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1() }; + + // lhs = 5 = 2 ^ 3 ^ 4 + sl[0]->data()[0] = 2; + sl[1]->data()[0] = 3; + sl[2]->data()[0] = 4; + + auto p = gen1(); + + BTensor bl0(sl[0].get(), sl[1].get()); + BTensor bl1(sl[1].get(), sl[2].get()); + BTensor bl2(sl[2].get(), sl[0].get()); + + BTensor bout0(sout[0].get(), sout[1].get()); + BTensor bout1(sout[2].get(), sout[3].get()); + BTensor bout2(sout[4].get(), sout[5].get()); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + bl0.lshift(1, &bout0); + bout0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + bl1.lshift(1, &bout1); + bout1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + bl2.lshift(1, &bout2); + bout2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(5 << 1, p->data()[0]); } TEST_F(BooleanTensorTest, rshift_test) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = 5 = 2 ^ 3 ^ 4 - sl[0]->data()[0] = 2; - sl[1]->data()[0] = 3; - sl[2]->data()[0] = 4; - - auto p = gen1(); - - BTensor bl0(sl[0].get(), sl[1].get()); - BTensor bl1(sl[1].get(), sl[2].get()); - BTensor bl2(sl[2].get(), sl[0].get()); - - BTensor bout0(sout[0].get(), sout[1].get()); - BTensor bout1(sout[2].get(), sout[3].get()); - BTensor bout2(sout[4].get(), sout[5].get()); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { - bl0.rshift(1, &bout0); - bout0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[1], - [&]() { - bl1.rshift(1, &bout1); - bout1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { - bl2.rshift(1, &bout2); - bout2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(5 >> 1, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1() }; + + // lhs = 5 = 2 ^ 3 ^ 4 + sl[0]->data()[0] = 2; + sl[1]->data()[0] = 3; + sl[2]->data()[0] = 4; + + auto p = gen1(); + + BTensor bl0(sl[0].get(), sl[1].get()); + BTensor bl1(sl[1].get(), sl[2].get()); + BTensor bl2(sl[2].get(), sl[0].get()); + + BTensor 
bout0(sout[0].get(), sout[1].get()); + BTensor bout1(sout[2].get(), sout[3].get()); + BTensor bout2(sout[4].get(), sout[5].get()); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + bl0.rshift(1, &bout0); + bout0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + bl1.rshift(1, &bout1); + bout1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + bl2.rshift(1, &bout2); + bout2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(5 >> 1, p->data()[0]); } TEST_F(BooleanTensorTest, logical_rshift_test) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = -1 - sl[0]->data()[0] = -1; - sl[1]->data()[0] = 0; - sl[2]->data()[0] = 0; - - auto p = gen1(); - - BTensor bl0(sl[0].get(), sl[1].get()); - BTensor bl1(sl[1].get(), sl[2].get()); - BTensor bl2(sl[2].get(), sl[0].get()); - - BTensor bout0(sout[0].get(), sout[1].get()); - BTensor bout1(sout[2].get(), sout[3].get()); - BTensor bout2(sout[4].get(), sout[5].get()); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { - bl0.logical_rshift(1, &bout0); - bout0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[1], - [&]() { - bl1.logical_rshift(1, &bout1); - bout1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { - bl2.logical_rshift(1, &bout2); - bout2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(-1ull >> 1, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1() }; + + // lhs = -1 + sl[0]->data()[0] = -1; + sl[1]->data()[0] = 0; + sl[2]->data()[0] = 0; + + auto p = gen1(); + + BTensor bl0(sl[0].get(), sl[1].get()); + BTensor bl1(sl[1].get(), sl[2].get()); + BTensor bl2(sl[2].get(), sl[0].get()); + + BTensor bout0(sout[0].get(), sout[1].get()); + BTensor bout1(sout[2].get(), sout[3].get()); + BTensor bout2(sout[4].get(), sout[5].get()); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + bl0.logical_rshift(1, &bout0); + bout0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + bl1.logical_rshift(1, &bout1); + bout1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + bl2.logical_rshift(1, &bout2); + bout2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(-1ull >> 1, p->data()[0]); } TEST_F(BooleanTensorTest, ppa_test) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - std::shared_ptr> sr[3] = {gen1(), gen1(), gen1()}; - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = 5 = 2 ^ 3 ^ 4 - sl[0]->data()[0] = 2; - sl[1]->data()[0] = 3; - sl[2]->data()[0] = 4; - - // rhs = 
7 = 1 ^ 2 ^ 4 - sr[0]->data()[0] = 1; - sr[1]->data()[0] = 2; - sr[2]->data()[0] = 4; - - auto p = gen1(); - - BTensor bl0(sl[0].get(), sl[1].get()); - BTensor bl1(sl[1].get(), sl[2].get()); - BTensor bl2(sl[2].get(), sl[0].get()); - - BTensor br0(sr[0].get(), sr[1].get()); - BTensor br1(sr[1].get(), sr[2].get()); - BTensor br2(sr[2].get(), sr[0].get()); - - BTensor bout0(sout[0].get(), sout[1].get()); - BTensor bout1(sout[2].get(), sout[3].get()); - BTensor bout2(sout[4].get(), sout[5].get()); - - const size_t nbits = 64; - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { - bl0.ppa(&br0, &bout0, nbits); - bout0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[1], - [&]() { - bl1.ppa(&br1, &bout1, nbits); - bout1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { - bl2.ppa(&br2, &bout2, nbits); - bout2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(5 + 7, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sr[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1() }; + + // lhs = 5 = 2 ^ 3 ^ 4 + sl[0]->data()[0] = 2; + sl[1]->data()[0] = 3; + sl[2]->data()[0] = 4; + + // rhs = 7 = 1 ^ 2 ^ 4 + sr[0]->data()[0] = 1; + sr[1]->data()[0] = 2; + sr[2]->data()[0] = 4; + + auto p = gen1(); + + BTensor bl0(sl[0].get(), sl[1].get()); + BTensor bl1(sl[1].get(), sl[2].get()); + BTensor bl2(sl[2].get(), sl[0].get()); + + BTensor br0(sr[0].get(), sr[1].get()); + BTensor br1(sr[1].get(), sr[2].get()); + BTensor br2(sr[2].get(), sr[0].get()); + + BTensor bout0(sout[0].get(), sout[1].get()); + BTensor bout1(sout[2].get(), sout[3].get()); + BTensor bout2(sout[4].get(), sout[5].get()); + + const size_t nbits = 64; + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + bl0.ppa(&br0, &bout0, nbits); + bout0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + bl1.ppa(&br1, &bout1, nbits); + bout1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + bl2.ppa(&br2, &bout2, nbits); + bout2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(5 + 7, p->data()[0]); } using FTensor = FixedPointTensor; TEST_F(BooleanTensorTest, b2a_test) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = 5 = 2 ^ 3 ^ 4 - sl[0]->data()[0] = 2; - sl[1]->data()[0] = 3; - sl[2]->data()[0] = 4; - - BTensor b0(sl[0].get(), sl[1].get()); - BTensor b1(sl[1].get(), sl[2].get()); - BTensor b2(sl[2].get(), sl[0].get()); - - FTensor f0(sout[0].get(), sout[1].get()); - FTensor f1(sout[2].get(), sout[3].get()); - FTensor f2(sout[4].get(), sout[5].get()); - - auto p = gen1(); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { - b0.b2a(&f0); - f0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), 
_mpc_ctx[1], - [&]() { - b1.b2a(&f1); - f1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { - b2.b2a(&f2); - f2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(5, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1()}; + + // lhs = 5 = 2 ^ 3 ^ 4 + sl[0]->data()[0] = 2; + sl[1]->data()[0] = 3; + sl[2]->data()[0] = 4; + + BTensor b0(sl[0].get(), sl[1].get()); + BTensor b1(sl[1].get(), sl[2].get()); + BTensor b2(sl[2].get(), sl[0].get()); + + FTensor f0(sout[0].get(), sout[1].get()); + FTensor f1(sout[2].get(), sout[3].get()); + FTensor f2(sout[4].get(), sout[5].get()); + + auto p = gen1(); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + b0.b2a(&f0); + f0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + b1.b2a(&f1); + f1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + b2.b2a(&f2); + f2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(5, p->data()[0]); } TEST_F(BooleanTensorTest, a2b_test) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = 9 = 2 + 3 + 4 - sl[0]->data()[0] = 2; - sl[1]->data()[0] = 3; - sl[2]->data()[0] = 4; - - BTensor b0(sout[0].get(), sout[1].get()); - BTensor b1(sout[2].get(), sout[3].get()); - BTensor b2(sout[4].get(), sout[5].get()); - - FTensor f0(sl[0].get(), sl[1].get()); - FTensor f1(sl[1].get(), sl[2].get()); - FTensor f2(sl[2].get(), sl[0].get()); - - auto p = gen1(); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { - b0 = &f0; - b0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[1], - [&]() { - b1 = &f1; - b1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { - b2 = &f2; - b2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(9, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1()}; + + // lhs = 9 = 2 + 3 + 4 + sl[0]->data()[0] = 2; + sl[1]->data()[0] = 3; + sl[2]->data()[0] = 4; + + BTensor b0(sout[0].get(), sout[1].get()); + BTensor b1(sout[2].get(), sout[3].get()); + BTensor b2(sout[4].get(), sout[5].get()); + + FTensor f0(sl[0].get(), sl[1].get()); + FTensor f1(sl[1].get(), sl[2].get()); + FTensor f2(sl[2].get(), sl[0].get()); + + auto p = gen1(); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + b0 = &f0; + b0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + b1 = &f1; + b1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), 
_mpc_ctx[2], [&](){ + b2 = &f2; + b2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(9, p->data()[0]); } TEST_F(BooleanTensorTest, bit_extract_test) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = 9 = 2 + 3 + 4 - sl[0]->data()[0] = 2; - sl[1]->data()[0] = 3; - sl[2]->data()[0] = 4; - - FTensor f0(sl[0].get(), sl[1].get()); - FTensor f1(sl[1].get(), sl[2].get()); - FTensor f2(sl[2].get(), sl[0].get()); - - BTensor b0(sout[0].get(), sout[1].get()); - BTensor b1(sout[2].get(), sout[3].get()); - BTensor b2(sout[4].get(), sout[5].get()); - - auto p = gen1(); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { - b0.bit_extract(3, &f0); - b0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[1], - [&]() { - b1.bit_extract(3, &f1); - b1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { - b2.bit_extract(3, &f2); - b2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(1, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1()}; + + // lhs = 9 = 2 + 3 + 4 + sl[0]->data()[0] = 2; + sl[1]->data()[0] = 3; + sl[2]->data()[0] = 4; + + FTensor f0(sl[0].get(), sl[1].get()); + FTensor f1(sl[1].get(), sl[2].get()); + FTensor f2(sl[2].get(), sl[0].get()); + + BTensor b0(sout[0].get(), sout[1].get()); + BTensor b1(sout[2].get(), sout[3].get()); + BTensor b2(sout[4].get(), sout[5].get()); + + auto p = gen1(); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + b0.bit_extract(3, &f0); + b0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + b1.bit_extract(3, &f1); + b1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + b2.bit_extract(3, &f2); + b2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(1, p->data()[0]); } TEST_F(BooleanTensorTest, boolean_bit_extract_test) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = 5 = 2 ^ 3 ^ 4 - sl[0]->data()[0] = 2; - sl[1]->data()[0] = 3; - sl[2]->data()[0] = 4; - - BTensor bl0(sl[0].get(), sl[1].get()); - BTensor bl1(sl[1].get(), sl[2].get()); - BTensor bl2(sl[2].get(), sl[0].get()); - - BTensor b0(sout[0].get(), sout[1].get()); - BTensor b1(sout[2].get(), sout[3].get()); - BTensor b2(sout[4].get(), sout[5].get()); - - auto p = gen1(); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { - bl0.bit_extract(2, &b0); - b0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[1], - [&]() { - bl1.bit_extract(2, &b1); - b1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { - 
bl2.bit_extract(2, &b2); - b2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(1, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1()}; + + // lhs = 5 = 2 ^ 3 ^ 4 + sl[0]->data()[0] = 2; + sl[1]->data()[0] = 3; + sl[2]->data()[0] = 4; + + BTensor bl0(sl[0].get(), sl[1].get()); + BTensor bl1(sl[1].get(), sl[2].get()); + BTensor bl2(sl[2].get(), sl[0].get()); + + BTensor b0(sout[0].get(), sout[1].get()); + BTensor b1(sout[2].get(), sout[3].get()); + BTensor b2(sout[4].get(), sout[5].get()); + + auto p = gen1(); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + bl0.bit_extract(2, &b0); + b0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + bl1.bit_extract(2, &b1); + b1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + bl2.bit_extract(2, &b2); + b2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(1, p->data()[0]); } TEST_F(BooleanTensorTest, bit_extract_test2) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = -9 = -2 + -3 + -4 - sl[0]->data()[0] = -2; - sl[1]->data()[0] = -3; - sl[2]->data()[0] = -4; - - FTensor f0(sl[0].get(), sl[1].get()); - FTensor f1(sl[1].get(), sl[2].get()); - FTensor f2(sl[2].get(), sl[0].get()); - - BTensor b0(sout[0].get(), sout[1].get()); - BTensor b1(sout[2].get(), sout[3].get()); - BTensor b2(sout[4].get(), sout[5].get()); - - auto p = gen1(); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { - b0.bit_extract(63, &f0); - b0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[1], - [&]() { - b1.bit_extract(63, &f1); - b1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { - b2.bit_extract(63, &f2); - b2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(1, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1()}; + + // lhs = -9 = -2 + -3 + -4 + sl[0]->data()[0] = -2; + sl[1]->data()[0] = -3; + sl[2]->data()[0] = -4; + + FTensor f0(sl[0].get(), sl[1].get()); + FTensor f1(sl[1].get(), sl[2].get()); + FTensor f2(sl[2].get(), sl[0].get()); + + BTensor b0(sout[0].get(), sout[1].get()); + BTensor b1(sout[2].get(), sout[3].get()); + BTensor b2(sout[4].get(), sout[5].get()); + + auto p = gen1(); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + b0.bit_extract(63, &f0); + b0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + b1.bit_extract(63, &f1); + b1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + b2.bit_extract(63, &f2); + 
b2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(1, p->data()[0]); } TEST_F(BooleanTensorTest, bit_extract_test3) { - std::vector shape = {2, 2}; - std::shared_ptr> sl[3] = {gen(shape), gen(shape), - gen(shape)}; - std::shared_ptr> sout[6] = { - gen(shape), gen(shape), gen(shape), gen(shape), gen(shape), gen(shape)}; - - // lhs = -65536 - sl[0]->data()[0] = 626067816440182033; - sl[1]->data()[0] = 1108923486657625775; - sl[2]->data()[0] = -1734991303097873344; - - sl[0]->data()[1] = -1320209182212830031; - sl[1]->data()[1] = 3175682926293206038; - sl[2]->data()[1] = -1855473744080441543; - - sl[0]->data()[2] = -7241979567589308516; - sl[1]->data()[2] = 5579083190137080035; - sl[2]->data()[2] = 1662896377452162945; - - sl[0]->data()[3] = 1468124374943170272; - sl[1]->data()[3] = -4796789375126030707; - sl[2]->data()[3] = 3328665000182794899; - - FTensor f0(sl[0].get(), sl[1].get()); - FTensor f1(sl[1].get(), sl[2].get()); - FTensor f2(sl[2].get(), sl[0].get()); - - BTensor b0(sout[0].get(), sout[1].get()); - BTensor b1(sout[2].get(), sout[3].get()); - BTensor b2(sout[4].get(), sout[5].get()); - - auto p = gen(shape); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { - b0.bit_extract(63, &f0); - b0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[1], - [&]() { - b1.bit_extract(63, &f1); - b1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { - b2.bit_extract(63, &f2); - b2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(1, p->data()[0]); - EXPECT_EQ(1, p->data()[1]); - EXPECT_EQ(1, p->data()[2]); - EXPECT_EQ(1, p->data()[3]); + std::vector shape = {2, 2}; + std::shared_ptr> sl[3] = { gen(shape), gen(shape), gen(shape) }; + std::shared_ptr> sout[6] = { gen(shape), gen(shape), gen(shape), + gen(shape), gen(shape), gen(shape)}; + + // lhs = -65536 + sl[0]->data()[0] = 626067816440182033; + sl[1]->data()[0] = 1108923486657625775; + sl[2]->data()[0] = -1734991303097873344; + + sl[0]->data()[1] = -1320209182212830031 ; + sl[1]->data()[1] = 3175682926293206038; + sl[2]->data()[1] = -1855473744080441543; + + sl[0]->data()[2] = -7241979567589308516; + sl[1]->data()[2] = 5579083190137080035; + sl[2]->data()[2] = 1662896377452162945; + + sl[0]->data()[3] = 1468124374943170272; + sl[1]->data()[3] = -4796789375126030707; + sl[2]->data()[3] = 3328665000182794899; + + FTensor f0(sl[0].get(), sl[1].get()); + FTensor f1(sl[1].get(), sl[2].get()); + FTensor f2(sl[2].get(), sl[0].get()); + + BTensor b0(sout[0].get(), sout[1].get()); + BTensor b1(sout[2].get(), sout[3].get()); + BTensor b2(sout[4].get(), sout[5].get()); + + auto p = gen(shape); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + b0.bit_extract(63, &f0); + b0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + b1.bit_extract(63, &f1); + b1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + b2.bit_extract(63, &f2); + b2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(1, 
p->data()[0]); + EXPECT_EQ(1, p->data()[1]); + EXPECT_EQ(1, p->data()[2]); + EXPECT_EQ(1, p->data()[3]); } TEST_F(BooleanTensorTest, abmul_test) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = 1 - sl[0]->data()[0] = 1; - sl[1]->data()[0] = 0; - sl[2]->data()[0] = 0; - - BTensor b0(sl[0].get(), sl[1].get()); - BTensor b1(sl[1].get(), sl[2].get()); - BTensor b2(sl[2].get(), sl[0].get()); - - FTensor f0(sout[0].get(), sout[1].get()); - FTensor f1(sout[2].get(), sout[3].get()); - FTensor f2(sout[4].get(), sout[5].get()); - - auto p = gen1(); - - // rhs = 7 - p->data()[0] = 7; - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { - b0.mul(p.get(), &f0, 0); - f0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[1], - [&]() { - b1.mul(nullptr, &f1, 0); - f1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { - b2.mul(nullptr, &f2, 0); - f2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(1 * 7, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1()}; + + // lhs = 1 + sl[0]->data()[0] = 1; + sl[1]->data()[0] = 0; + sl[2]->data()[0] = 0; + + BTensor b0(sl[0].get(), sl[1].get()); + BTensor b1(sl[1].get(), sl[2].get()); + BTensor b2(sl[2].get(), sl[0].get()); + + FTensor f0(sout[0].get(), sout[1].get()); + FTensor f1(sout[2].get(), sout[3].get()); + FTensor f2(sout[4].get(), sout[5].get()); + + auto p = gen1(); + + // rhs = 7 + p->data()[0] = 7; + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + b0.mul(p.get(), &f0, 0); + f0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + b1.mul(nullptr, &f1, 0); + f1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + b2.mul(nullptr, &f2, 0); + f2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(1 * 7, p->data()[0]); } TEST_F(BooleanTensorTest, abmul2_test) { - std::shared_ptr> sl[3] = {gen1(), gen1(), gen1()}; - std::shared_ptr> sr[3] = {gen1(), gen1(), gen1()}; - std::shared_ptr> sout[6] = {gen1(), gen1(), gen1(), - gen1(), gen1(), gen1()}; - - // lhs = 1 - sl[0]->data()[0] = 1; - sl[1]->data()[0] = 0; - sl[2]->data()[0] = 0; - - // rhs = 12 = 3 + 4 + 5 - sr[0]->data()[0] = 3; - sr[1]->data()[0] = 4; - sr[2]->data()[0] = 5; - - BTensor bl0(sl[0].get(), sl[1].get()); - BTensor bl1(sl[1].get(), sl[2].get()); - BTensor bl2(sl[2].get(), sl[0].get()); - - FTensor fr0(sr[0].get(), sr[1].get()); - FTensor fr1(sr[1].get(), sr[2].get()); - FTensor fr2(sr[2].get(), sr[0].get()); - - FTensor fout0(sout[0].get(), sout[1].get()); - FTensor fout1(sout[2].get(), sout[3].get()); - FTensor fout2(sout[4].get(), sout[5].get()); - - auto p = gen1(); - - _t[0] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[0], - [&]() { - bl0.mul(&fr0, &fout0); - fout0.reveal_to_one(0, p.get()); - }); - }); - - _t[1] = std::thread([&]() { - 
ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[1], - [&]() { - bl1.mul(&fr1, &fout1); - fout1.reveal_to_one(0, nullptr); - }); - }); - - _t[2] = std::thread([&]() { - ContextHolder::template run_with_context(_exec_ctx.get(), _mpc_ctx[2], - [&]() { - bl2.mul(&fr2, &fout2); - fout2.reveal_to_one(0, nullptr); - }); - }); - for (auto &t : _t) { - t.join(); - } - EXPECT_EQ(1 * 12, p->data()[0]); + std::shared_ptr> sl[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sr[3] = { gen1(), gen1(), gen1() }; + std::shared_ptr> sout[6] = { gen1(), gen1(), gen1(), + gen1(), gen1(), gen1()}; + + // lhs = 1 + sl[0]->data()[0] = 1; + sl[1]->data()[0] = 0; + sl[2]->data()[0] = 0; + + // rhs = 12 = 3 + 4 + 5 + sr[0]->data()[0] = 3; + sr[1]->data()[0] = 4; + sr[2]->data()[0] = 5; + + BTensor bl0(sl[0].get(), sl[1].get()); + BTensor bl1(sl[1].get(), sl[2].get()); + BTensor bl2(sl[2].get(), sl[0].get()); + + FTensor fr0(sr[0].get(), sr[1].get()); + FTensor fr1(sr[1].get(), sr[2].get()); + FTensor fr2(sr[2].get(), sr[0].get()); + + FTensor fout0(sout[0].get(), sout[1].get()); + FTensor fout1(sout[2].get(), sout[3].get()); + FTensor fout2(sout[4].get(), sout[5].get()); + + auto p = gen1(); + + _t[0] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + bl0.mul(&fr0, &fout0); + fout0.reveal_to_one(0, p.get()); + }); + } + ); + + _t[1] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + bl1.mul(&fr1, &fout1); + fout1.reveal_to_one(0, nullptr); + }); + } + ); + + _t[2] = std::thread( + [&] () { + ContextHolder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + bl2.mul(&fr2, &fout2); + fout2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ(1 * 12, p->data()[0]); } } // namespace aby3 diff --git a/core/privc3/fixedpoint_tensor_test.cc b/core/privc3/fixedpoint_tensor_test.cc index 5dc905c803d0d2a26b94ae0c784e723d50be2550..0828594dc18805a2a5bff568a820d2e151b7f8d5 100644 --- a/core/privc3/fixedpoint_tensor_test.cc +++ b/core/privc3/fixedpoint_tensor_test.cc @@ -1,2510 +1,2562 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ #include +#include "gtest/gtest.h" +#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/tensor.h" -#include "gtest/gtest.h" -#include "fixedpoint_tensor.h" -#include "core/paddlefl_mpc/mpc_protocol/context_holder.h" #include "core/paddlefl_mpc/mpc_protocol/mesh_network.h" +#include "core/paddlefl_mpc/mpc_protocol/context_holder.h" +#include "fixedpoint_tensor.h" namespace aby3 { -using g_ctx_holder = paddle::mpc::ContextHolder; -using Fix64N16 = FixedPointTensor; + using g_ctx_holder = paddle::mpc::ContextHolder; + using Fix64N16 = FixedPointTensor; class FixedTensorTest : public ::testing::Test { public: - paddle::platform::CPUDeviceContext _cpu_ctx; - std::shared_ptr _exec_ctx; - std::shared_ptr _mpc_ctx[3]; - std::shared_ptr _store; - std::thread _t[3]; - std::shared_ptr _s_tensor_factory; - - void SetUp() { - - paddle::framework::OperatorBase *op = nullptr; - paddle::framework::Scope scope; - paddle::framework::RuntimeContext ctx({}, {}); - // only device_ctx is needed - _exec_ctx = std::make_shared( - *op, scope, _cpu_ctx, ctx, nullptr); - - _store = std::make_shared(); - - std::thread t[3]; - for (size_t i = 0; i < 3; ++i) { - _t[i] = std::thread(&FixedTensorTest::gen_mpc_ctx, this, i); + + paddle::platform::CPUDeviceContext _cpu_ctx; + std::shared_ptr _exec_ctx; + std::shared_ptr _mpc_ctx[3]; + std::shared_ptr _store; + std::thread _t[3]; + std::shared_ptr _s_tensor_factory; + + virtual ~FixedTensorTest() noexcept {} + + void SetUp() { + + paddle::framework::OperatorBase* op = nullptr; + paddle::framework::Scope scope; + paddle::framework::RuntimeContext ctx({}, {}); + // only device_ctx is needed + _exec_ctx = std::make_shared( + *op, scope, _cpu_ctx, ctx); + + _store = std::make_shared(); + + std::thread t[3]; + for (size_t i = 0; i < 3; ++i) { + _t[i] = std::thread(&FixedTensorTest::gen_mpc_ctx, this, i); + } + for (auto& ti : _t) { + ti.join(); + } + _s_tensor_factory = std::make_shared(&_cpu_ctx); } - for (auto &ti : _t) { - ti.join(); + std::shared_ptr gen_network(size_t idx) { + return std::make_shared(idx, + "127.0.0.1", + 3, + "test_prefix", + _store); + } + void gen_mpc_ctx(size_t idx) { + auto net = gen_network(idx); + net->init(); + _mpc_ctx[idx] = std::make_shared(idx, net); + } + + std::shared_ptr> gen(std::vector shape) { + return _s_tensor_factory->template create(shape); } - _s_tensor_factory = std::make_shared(&_cpu_ctx); - } - std::shared_ptr gen_network(size_t idx) { - return std::make_shared(idx, "127.0.0.1", 3, - "test_prefix", _store); - } - void gen_mpc_ctx(size_t idx) { - auto net = gen_network(idx); - net->init(); - _mpc_ctx[idx] = std::make_shared(idx, net); - } - - std::shared_ptr> gen(std::vector shape) { - return _s_tensor_factory->template create(shape); - } }; std::shared_ptr> gen(std::vector shape) { - return g_ctx_holder::tensor_factory()->template create(shape); -} - -template -PaddleTensor -test_fixedt_gen_paddle_tensor(std::vector &input, - std::vector &shape, - paddle::platform::CPUDeviceContext &cpu_ctx) { - - PaddleTensor ret(&cpu_ctx); - ret.reshape(shape); - T *ret_ptr = ret.data(); - for (int i = 0; i < ret.numel(); i++) { - *(ret_ptr + i) = (T)(input[i] * pow(2, N)); - } - return ret; -} - -template -bool test_fixedt_check_tensor_eq(const TensorAdapter *in1, - const TensorAdapter *in2, - double precision = 0.0001) { - // check shape - std::vector shape1, shape2; - shape1 = in1->shape(); - shape2 = in2->shape(); - size_t scale 
= in1->scaling_factor(); - if (shape1.size() != shape2.size()) { - std::cout << "shape size error: shape1.size: " << shape1.size() - << "; shape2.size: " << shape2.size() << std::endl; - return false; - } - for (int i = 0; i < shape1.size(); i++) { - if (shape1[i] != shape2[i]) { - std::cout << "shape error!" << std::endl; - return false; + return g_ctx_holder::tensor_factory()->template create(shape); +} + +template +PaddleTensor test_fixedt_gen_paddle_tensor(std::vector& input, + std::vector& shape, + paddle::platform::CPUDeviceContext& cpu_ctx) { + + PaddleTensor ret(&cpu_ctx); + ret.reshape(shape); + T* ret_ptr = ret.data(); + for (int i = 0; i < ret.numel(); i++) { + *(ret_ptr + i) = (T) (input[i] * pow(2, N)); + } + return ret; +} + +template +bool test_fixedt_check_tensor_eq(const TensorAdapter* in1, + const TensorAdapter* in2, double precision = 0.0001) { + // check shape + std::vector shape1, shape2; + shape1 = in1->shape(); + shape2 = in2->shape(); + size_t scale = in1->scaling_factor(); + if (shape1.size() != shape2.size()) { + std::cout << "shape size error: shape1.size: "<numel(); i++) { - if (std::abs(*(in1->data() + i) - *(in2->data() + i)) > - precision * pow(2, scale)) { - std::cout << "result error: inx: " << i - << " in1[i] = " << *(in1->data() + i) - << " in2[i] = " << *(in2->data() + i) << std::endl; - return false; + for (int i = 0; i < shape1.size(); i++) { + if (shape1[i] != shape2[i]) { + std::cout << "shape error!"<numel(); i++) { + if (std::abs(*(in1->data() + i) - *(in2->data() + i)) > + precision * pow(2, scale)) { + std::cout << "result error: inx: "< _tensor_factory; - CPUDeviceContext _cpu_ctx; + std::shared_ptr _tensor_factory; + + CPUDeviceContext _cpu_ctx; - void SetUp() { - _tensor_factory = std::make_shared(&_cpu_ctx); - } + virtual ~PaddleTensorTest() noexcept {} + + void SetUp() { + _tensor_factory = std::make_shared(&_cpu_ctx); + } }; TEST_F(PaddleTensorTest, factory_test) { - EXPECT_NO_THROW(_tensor_factory->template create()); - std::vector shape = {2, 3}; - EXPECT_NO_THROW(_tensor_factory->template create(shape)); + EXPECT_NO_THROW(_tensor_factory->template create()); + std::vector shape = { 2, 3 }; + EXPECT_NO_THROW(_tensor_factory->template create(shape)); } TEST_F(PaddleTensorTest, ctor_test) { - Tensor t; - // t holds no memory - EXPECT_THROW({ PaddleTensor pt(&_cpu_ctx, t); }, - ::paddle::platform::EnforceNotMet); - t.template mutable_data(_cpu_ctx.GetPlace()); - EXPECT_NO_THROW({ PaddleTensor pt(&_cpu_ctx, t); }); + Tensor t; + // t holds no memory + EXPECT_THROW({ PaddleTensor pt(&_cpu_ctx, t); }, ::paddle::platform::EnforceNotMet); + t.template mutable_data(_cpu_ctx.GetPlace()); + EXPECT_NO_THROW({ PaddleTensor pt(&_cpu_ctx, t); }); } TEST_F(PaddleTensorTest, shape_test) { - std::vector shape = {2, 3}; - auto pt = _tensor_factory->template create(shape); + std::vector shape = { 2, 3 }; + auto pt = _tensor_factory->template create(shape); - EXPECT_EQ(shape.size(), pt->shape().size()); + EXPECT_EQ(shape.size(), pt->shape().size()); - bool eq = std::equal(shape.begin(), shape.end(), pt->shape().begin()); - EXPECT_TRUE(eq); + bool eq = std::equal(shape.begin(), shape.end(), pt->shape().begin()); + EXPECT_TRUE(eq); - EXPECT_EQ(6u, pt->numel()); + EXPECT_EQ(6u, pt->numel()); } TEST_F(PaddleTensorTest, reshape_test) { - std::vector shape = {2, 3}; - auto pt = _tensor_factory->template create(); + std::vector shape = { 2, 3 }; + auto pt = _tensor_factory->template create(); - pt->reshape(shape); + pt->reshape(shape); - EXPECT_EQ(shape.size(), 
pt->shape().size()); + EXPECT_EQ(shape.size(), pt->shape().size()); - bool eq = std::equal(shape.begin(), shape.end(), pt->shape().begin()); - EXPECT_TRUE(eq); + bool eq = std::equal(shape.begin(), shape.end(), pt->shape().begin()); + EXPECT_TRUE(eq); } TEST_F(PaddleTensorTest, add_test) { - std::vector shape = {1}; - auto pt0 = _tensor_factory->template create(shape); - auto pt1 = _tensor_factory->template create(shape); - auto pt2 = _tensor_factory->template create(shape); - pt0->data()[0] = 1; - pt1->data()[0] = 2; - pt0->add(pt1.get(), pt2.get()); - - EXPECT_EQ(3, pt2->data()[0]); + std::vector shape = { 1 }; + auto pt0 = _tensor_factory->template create(shape); + auto pt1 = _tensor_factory->template create(shape); + auto pt2 = _tensor_factory->template create(shape); + pt0->data()[0] = 1; + pt1->data()[0] = 2; + pt0->add(pt1.get(), pt2.get()); + + EXPECT_EQ(3, pt2->data()[0]); } TEST_F(PaddleTensorTest, sub_test) { - std::vector shape = {1}; - auto pt0 = _tensor_factory->template create(shape); - auto pt1 = _tensor_factory->template create(shape); - auto pt2 = _tensor_factory->template create(shape); - pt0->data()[0] = 2; - pt1->data()[0] = 1; - pt0->sub(pt1.get(), pt2.get()); - - EXPECT_EQ(1, pt2->data()[0]); + std::vector shape = { 1 }; + auto pt0 = _tensor_factory->template create(shape); + auto pt1 = _tensor_factory->template create(shape); + auto pt2 = _tensor_factory->template create(shape); + pt0->data()[0] = 2; + pt1->data()[0] = 1; + pt0->sub(pt1.get(), pt2.get()); + + EXPECT_EQ(1, pt2->data()[0]); } TEST_F(PaddleTensorTest, negative_test) { - std::vector shape = {1}; - auto pt0 = _tensor_factory->template create(shape); - auto pt1 = _tensor_factory->template create(shape); - pt0->data()[0] = 2; - pt0->negative(pt1.get()); + std::vector shape = { 1 }; + auto pt0 = _tensor_factory->template create(shape); + auto pt1 = _tensor_factory->template create(shape); + pt0->data()[0] = 2; + pt0->negative(pt1.get()); - EXPECT_EQ(-2, pt1->data()[0]); + EXPECT_EQ(-2, pt1->data()[0]); } TEST_F(PaddleTensorTest, mul_test) { - std::vector shape = {1}; - auto pt0 = _tensor_factory->template create(shape); - auto pt1 = _tensor_factory->template create(shape); - auto pt2 = _tensor_factory->template create(shape); - pt0->data()[0] = 7; - pt1->data()[0] = 3; - pt0->mul(pt1.get(), pt2.get()); - - EXPECT_EQ(21, pt2->data()[0]); + std::vector shape = { 1 }; + auto pt0 = _tensor_factory->template create(shape); + auto pt1 = _tensor_factory->template create(shape); + auto pt2 = _tensor_factory->template create(shape); + pt0->data()[0] = 7; + pt1->data()[0] = 3; + pt0->mul(pt1.get(), pt2.get()); + + EXPECT_EQ(21, pt2->data()[0]); } TEST_F(PaddleTensorTest, div_test) { - std::vector shape = {1}; - auto pt0 = _tensor_factory->template create(shape); - auto pt1 = _tensor_factory->template create(shape); - auto pt2 = _tensor_factory->template create(shape); - pt0->data()[0] = 7; - pt1->data()[0] = 3; - pt0->div(pt1.get(), pt2.get()); - - EXPECT_EQ(2, pt2->data()[0]); + std::vector shape = { 1 }; + auto pt0 = _tensor_factory->template create(shape); + auto pt1 = _tensor_factory->template create(shape); + auto pt2 = _tensor_factory->template create(shape); + pt0->data()[0] = 7; + pt1->data()[0] = 3; + pt0->div(pt1.get(), pt2.get()); + + EXPECT_EQ(2, pt2->data()[0]); } TEST_F(PaddleTensorTest, matmul_test) { - std::vector shape0 = {2, 3}; - std::vector shape1 = {3, 2}; - std::vector shape2 = {2, 2}; - auto pt0 = _tensor_factory->template create(shape0); - auto pt1 = _tensor_factory->template create(shape1); - 
auto pt2 = _tensor_factory->template create(shape2); - for (size_t i = 0; i < 6; ++i) { - pt0->data()[i] = i; - pt1->data()[i] = i; - } - pt0->mat_mul(pt1.get(), pt2.get()); - - // | 0 1 2 | | 0 1 | | 10 13 | - // | 3 4 5 | x | 2 3 | = | 28 40 | - // | 4 5 | - - std::vector res = {10, 13, 28, 40}; - - bool eq = std::equal(res.begin(), res.end(), pt2->data()); - - EXPECT_TRUE(eq); + std::vector shape0 = { 2, 3 }; + std::vector shape1 = { 3, 2 }; + std::vector shape2 = { 2, 2 }; + auto pt0 = _tensor_factory->template create(shape0); + auto pt1 = _tensor_factory->template create(shape1); + auto pt2 = _tensor_factory->template create(shape2); + for (size_t i = 0; i < 6; ++i) { + pt0->data()[i] = i; + pt1->data()[i] = i; + } + pt0->mat_mul(pt1.get(), pt2.get()); + + // | 0 1 2 | | 0 1 | | 10 13 | + // | 3 4 5 | x | 2 3 | = | 28 40 | + // | 4 5 | + + std::vector res = { 10, 13, 28, 40 }; + + bool eq = std::equal(res.begin(), res.end(), pt2->data()); + + EXPECT_TRUE(eq); } TEST_F(PaddleTensorTest, xor_test) { - std::vector shape = {1}; - auto pt0 = _tensor_factory->template create(shape); - auto pt1 = _tensor_factory->template create(shape); - auto pt2 = _tensor_factory->template create(shape); - pt0->data()[0] = 3; - pt1->data()[0] = 7; - pt0->bitwise_xor(pt1.get(), pt2.get()); - - EXPECT_EQ(4, pt2->data()[0]); + std::vector shape = { 1 }; + auto pt0 = _tensor_factory->template create(shape); + auto pt1 = _tensor_factory->template create(shape); + auto pt2 = _tensor_factory->template create(shape); + pt0->data()[0] = 3; + pt1->data()[0] = 7; + pt0->bitwise_xor(pt1.get(), pt2.get()); + + EXPECT_EQ(4, pt2->data()[0]); } TEST_F(PaddleTensorTest, and_test) { - std::vector shape = {1}; - auto pt0 = _tensor_factory->template create(shape); - auto pt1 = _tensor_factory->template create(shape); - auto pt2 = _tensor_factory->template create(shape); - pt0->data()[0] = 3; - pt1->data()[0] = 7; - pt0->bitwise_and(pt1.get(), pt2.get()); - - EXPECT_EQ(3, pt2->data()[0]); + std::vector shape = { 1 }; + auto pt0 = _tensor_factory->template create(shape); + auto pt1 = _tensor_factory->template create(shape); + auto pt2 = _tensor_factory->template create(shape); + pt0->data()[0] = 3; + pt1->data()[0] = 7; + pt0->bitwise_and(pt1.get(), pt2.get()); + + EXPECT_EQ(3, pt2->data()[0]); } TEST_F(PaddleTensorTest, or_test) { - std::vector shape = {1}; - auto pt0 = _tensor_factory->template create(shape); - auto pt1 = _tensor_factory->template create(shape); - auto pt2 = _tensor_factory->template create(shape); - pt0->data()[0] = 3; - pt1->data()[0] = 7; - pt0->bitwise_or(pt1.get(), pt2.get()); - - EXPECT_EQ(7, pt2->data()[0]); + std::vector shape = { 1 }; + auto pt0 = _tensor_factory->template create(shape); + auto pt1 = _tensor_factory->template create(shape); + auto pt2 = _tensor_factory->template create(shape); + pt0->data()[0] = 3; + pt1->data()[0] = 7; + pt0->bitwise_or(pt1.get(), pt2.get()); + + EXPECT_EQ(7, pt2->data()[0]); } TEST_F(PaddleTensorTest, not_test) { - std::vector shape = {1}; - auto pt0 = _tensor_factory->template create(shape); - auto pt1 = _tensor_factory->template create(shape); - pt0->data()[0] = 0; - pt0->bitwise_not(pt1.get()); + std::vector shape = { 1 }; + auto pt0 = _tensor_factory->template create(shape); + auto pt1 = _tensor_factory->template create(shape); + pt0->data()[0] = 0; + pt0->bitwise_not(pt1.get()); - EXPECT_EQ(-1, pt1->data()[0]); + EXPECT_EQ(-1, pt1->data()[0]); } TEST_F(PaddleTensorTest, lshift_test) { - std::vector shape = {1}; - auto pt0 = _tensor_factory->template 
create(shape); - auto pt1 = _tensor_factory->template create(shape); - pt0->data()[0] = 2; - pt0->lshift(1, pt1.get()); + std::vector shape = { 1 }; + auto pt0 = _tensor_factory->template create(shape); + auto pt1 = _tensor_factory->template create(shape); + pt0->data()[0] = 2; + pt0->lshift(1, pt1.get()); - EXPECT_EQ(4, pt1->data()[0]); + EXPECT_EQ(4, pt1->data()[0]); } TEST_F(PaddleTensorTest, rshift_test) { - std::vector shape = {1}; - auto pt0 = _tensor_factory->template create(shape); - auto pt1 = _tensor_factory->template create(shape); - pt0->data()[0] = 2; - pt0->rshift(1, pt1.get()); + std::vector shape = { 1 }; + auto pt0 = _tensor_factory->template create(shape); + auto pt1 = _tensor_factory->template create(shape); + pt0->data()[0] = 2; + pt0->rshift(1, pt1.get()); - EXPECT_EQ(1, pt1->data()[0]); + EXPECT_EQ(1, pt1->data()[0]); } TEST_F(PaddleTensorTest, logical_rshift_test) { - std::vector shape = {1}; - auto pt0 = _tensor_factory->template create(shape); - auto pt1 = _tensor_factory->template create(shape); - pt0->data()[0] = -1; - pt0->logical_rshift(1, pt1.get()); + std::vector shape = { 1 }; + auto pt0 = _tensor_factory->template create(shape); + auto pt1 = _tensor_factory->template create(shape); + pt0->data()[0] = -1; + pt0->logical_rshift(1, pt1.get()); - EXPECT_EQ(-1ull >> 1, pt1->data()[0]); + EXPECT_EQ(-1ull >> 1, pt1->data()[0]); } + TEST_F(PaddleTensorTest, scale_test) { - auto pt = _tensor_factory->template create(); + auto pt = _tensor_factory->template create(); - auto pt_ = dynamic_cast *>(pt.get()); + auto pt_ = dynamic_cast*>(pt.get()); - pt_->scaling_factor() = 1; + pt_->scaling_factor() = 1; - Tensor t; + Tensor t; - int dim[1] = {1}; - paddle::framework::DDim ddim(dim, 1); - t.template mutable_data(ddim, _cpu_ctx.GetPlace()); + int dim[1] = { 1 }; + paddle::framework::DDim ddim(dim, 1); + t.template mutable_data(ddim, _cpu_ctx.GetPlace()); - t.template data()[0] = 0.25f; + t.template data()[0] = 0.25f; - pt_->template from_float_point_type(t, 2); + pt_->template from_float_point_type(t, 2); - EXPECT_EQ(2, pt_->scaling_factor()); - EXPECT_EQ(1, pt->data()[0]); + EXPECT_EQ(2, pt_->scaling_factor()); + EXPECT_EQ(1, pt->data()[0]); } TEST_F(PaddleTensorTest, scalar_test) { - auto pt = _tensor_factory->template create(); + auto pt = _tensor_factory->template create(); - auto pt_ = dynamic_cast *>(pt.get()); + auto pt_ = dynamic_cast*>(pt.get()); - pt_->scaling_factor() = 1; + pt_->scaling_factor() = 1; - std::vector shape = {2}; - pt_->template from_float_point_scalar(0.25f, shape, 2); + std::vector shape = { 2 }; + pt_->template from_float_point_scalar(0.25f, shape, 2); - EXPECT_EQ(2, pt_->scaling_factor()); - EXPECT_EQ(1, pt->data()[0]); - EXPECT_EQ(1, pt->data()[1]); + EXPECT_EQ(2, pt_->scaling_factor()); + EXPECT_EQ(1, pt->data()[0]); + EXPECT_EQ(1, pt->data()[1]); } TEST_F(PaddleTensorTest, slice_test) { - std::vector shape = {2, 2}; - auto pt = _tensor_factory->template create(shape); - auto ret = _tensor_factory->template create(); + std::vector shape = { 2, 2 }; + auto pt = _tensor_factory->template create(shape); + auto ret = _tensor_factory->template create(); - auto pt_ = dynamic_cast *>(pt.get()); - pt_->scaling_factor() = 1; + auto pt_ = dynamic_cast*>(pt.get()); + pt_->scaling_factor() = 1; - for (size_t i = 0; i < 4; ++i) { - pt->data()[0] = i; - } + for (size_t i = 0; i < 4; ++i) { + pt->data()[0] = i; + } - pt_->slice(1, 2, ret.get()); + pt_->slice(1, 2, ret.get()); - auto shape_ = ret->shape(); + auto shape_ = ret->shape(); - EXPECT_EQ(2, 
shape_.size()); - EXPECT_EQ(1, shape_[0]); - EXPECT_EQ(2, shape_[1]); + EXPECT_EQ(2, shape_.size()); + EXPECT_EQ(1, shape_[0]); + EXPECT_EQ(2, shape_[1]); - EXPECT_EQ(1, ret->scaling_factor()); + EXPECT_EQ(1, ret->scaling_factor()); - EXPECT_EQ(2, ret->data()[0]); - EXPECT_EQ(3, ret->data()[1]); + EXPECT_EQ(2, ret->data()[0]); + EXPECT_EQ(3, ret->data()[1]); } } // namespace aby3 diff --git a/python/paddle_fl/mpc/framework.py b/python/paddle_fl/mpc/framework.py index 858c2a6b9c3f99c32bdb77d448059bfcfc17aaed..6e33ecc99351209da753a151aa3d97cb3bfd063a 100644 --- a/python/paddle_fl/mpc/framework.py +++ b/python/paddle_fl/mpc/framework.py @@ -21,14 +21,13 @@ from paddle.fluid import core from paddle.fluid import unique_name from paddle.fluid.framework import Variable from paddle.fluid.framework import convert_np_dtype_to_dtype_ - +from paddle.fluid.data_feeder import check_type, check_dtype class MpcVariable(Variable): """ Extends from paddle.fluid.framework.Variable and rewrite the __init__ method where the shape is resized. """ - def __init__(self, block, type=core.VarDesc.VarType.LOD_TENSOR, @@ -91,22 +90,22 @@ class MpcVariable(Variable): else: old_dtype = self.dtype if dtype != old_dtype: - raise ValueError( - "MpcVariable {0} has been created before. " - "The previous data type is {1}; the new " - "data type is {2}. They are not " - "matched.".format(self.name, old_dtype, dtype)) + raise ValueError("MpcVariable {0} has been created before. " + "The previous data type is {1}; the new " + "data type is {2}. They are not " + "matched.".format(self.name, old_dtype, + dtype)) if lod_level is not None: if is_new_var: self.desc.set_lod_level(lod_level) else: if lod_level != self.lod_level: - raise ValueError( - "MpcVariable {0} has been created before. " - "The previous lod_level is {1}; the new " - "lod_level is {2}. They are not " - "matched".format(self.name, self.lod_level, lod_level)) + raise ValueError("MpcVariable {0} has been created before. " + "The previous lod_level is {1}; the new " + "lod_level is {2}. 
They are not " + "matched".format(self.name, self.lod_level, + lod_level)) if persistable is not None: if is_new_var: self.desc.set_persistable(persistable) @@ -156,8 +155,7 @@ class MpcParameter(MpcVariable): if len(shape) == 0: raise ValueError( - "The dimensions of shape for MpcParameter must be greater than 0" - ) + "The dimensions of shape for MpcParameter must be greater than 0") for each in shape: if each < 0: @@ -175,8 +173,7 @@ class MpcParameter(MpcVariable): **kwargs) self.trainable = kwargs.get('trainable', True) - self.optimize_attr = kwargs.get('optimize_attr', - {'learning_rate': 1.0}) + self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0}) self.regularizer = kwargs.get('regularizer', None) @@ -203,8 +200,8 @@ class MpcParameter(MpcVariable): additional_attr = ("trainable", "optimize_attr", "regularizer", "gradient_clip_attr", "do_model_average") for attr_name in additional_attr: - res_str += "%s: %s\n" % ( - attr_name, cpt.to_text(getattr(self, attr_name))) + res_str += "%s: %s\n" % (attr_name, + cpt.to_text(getattr(self, attr_name))) else: res_str = MpcVariable.to_string(self, throw_on_error, False) return res_str @@ -245,8 +242,7 @@ def create_mpc_parameter(block, *args, **kwargs): init_ops_len = len(init_ops) if init_ops_len > 1: raise RuntimeError("mpc_param " + mpc_param.name + - " is inited by multiple init ops " + str( - init_ops)) + " is inited by multiple init ops " + str(init_ops)) elif init_ops_len == 1: # TODO(Paddle 1.7): already inited, do nothing, should log a warning pass @@ -272,7 +268,6 @@ def create_mpc_var(block, *args, **kwargs): kwargs['initializer'](var, block) return var - def is_mpc_parameter(var): """ Check whether the given variable is an instance of MpcParameter. @@ -282,4 +277,13 @@ def is_mpc_parameter(var): bool: True if the given `var` is an instance of Parameter, False if not. """ - return isinstance(var, MpcParameter) + return type(var) == MpcParameter + +def check_mpc_variable_and_dtype(input, + input_name, + expected_dtype, + op_name, + extra_message=''): + check_type(input, input_name, MpcVariable, op_name, extra_message) + check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message) + diff --git a/python/paddle_fl/mpc/layers/__init__.py b/python/paddle_fl/mpc/layers/__init__.py index 2abf5788deb9d9c1772877d7d3b1b41af49a779e..08b07bc1d81b85c62cbebe7cc3a5edbd9629ceee 100644 --- a/python/paddle_fl/mpc/layers/__init__.py +++ b/python/paddle_fl/mpc/layers/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/paddle_fl/mpc/layers/basic.py b/python/paddle_fl/mpc/layers/basic.py index a61a7d72b6621a526c1482091533369089e947a0..dd84ec0e0ec405e460617d3f19a8ad644bedfcf3 100644 --- a/python/paddle_fl/mpc/layers/basic.py +++ b/python/paddle_fl/mpc/layers/basic.py @@ -14,9 +14,10 @@ """ basic mpc op layers. 
""" -from paddle.fluid.data_feeder import check_type_and_dtype +from paddle.fluid.data_feeder import check_variable_and_dtype from ..framework import MpcVariable +from ..framework import check_mpc_variable_and_dtype from ..mpc_layer_helper import MpcLayerHelper __all__ = [ @@ -32,8 +33,8 @@ def _elementwise_op(helper): assert x is not None, 'x cannot be None in {}'.format(op_type) assert y is not None, 'y cannot be None in {}'.format(op_type) - check_type_and_dtype(x, 'x', MpcVariable, ['int64'], op_type) - check_type_and_dtype(y, 'y', MpcVariable, ['int64'], op_type) + check_mpc_variable_and_dtype(x, 'x', ['int64'], op_type) + check_mpc_variable_and_dtype(y, 'y', ['int64'], op_type) axis = helper.kwargs.get('axis', -1) use_mkldnn = helper.kwargs.get('use_mkldnn', False) diff --git a/python/paddle_fl/mpc/layers/compare.py b/python/paddle_fl/mpc/layers/compare.py index e7e862200b8e3c7689be3e6f8df257a39eaa8376..e2f1337a4755a1db5ac3de9f881189766e2e6e5d 100644 --- a/python/paddle_fl/mpc/layers/compare.py +++ b/python/paddle_fl/mpc/layers/compare.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,6 @@ """ mpc math compare layers. """ -from paddle.fluid.data_feeder import check_type_and_dtype from ..framework import MpcVariable from ..mpc_layer_helper import MpcLayerHelper diff --git a/python/paddle_fl/mpc/layers/math.py b/python/paddle_fl/mpc/layers/math.py index 9cb06fb423c07320f626cddbe8fd9bea60352b78..7e9d27007c839b59f3a82e3a087eb172963a044f 100644 --- a/python/paddle_fl/mpc/layers/math.py +++ b/python/paddle_fl/mpc/layers/math.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,9 +14,9 @@ """ mpc math op layers. 
""" -from paddle.fluid.data_feeder import check_type_and_dtype from ..framework import MpcVariable +from ..framework import check_mpc_variable_and_dtype from ..mpc_layer_helper import MpcLayerHelper __all__ = [ @@ -39,7 +39,7 @@ def mean(x, name=None): Examples: todo """ helper = MpcLayerHelper("mean", **locals()) - check_type_and_dtype(x, 'x', MpcVariable, ['int64'], 'mean') + check_mpc_variable_and_dtype(x, 'x', ['int64'], 'mean') if name is None: out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype) else: @@ -64,7 +64,7 @@ def square(x, name=None): Examples: todo """ helper = MpcLayerHelper("square", **locals()) - check_type_and_dtype(x, 'x', MpcVariable, ['int64'], 'square') + check_mpc_variable_and_dtype(x, 'x', ['int64'], 'square') if name is None: out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype) else: @@ -89,8 +89,7 @@ def sum(x): Examples: todo """ helper = MpcLayerHelper("sum", **locals()) - out = helper.create_mpc_variable_for_type_inference( - dtype=helper.input_dtype('x')) + out = helper.create_mpc_variable_for_type_inference(dtype=helper.input_dtype('x')) helper.append_op( type="mpc_sum", inputs={"X": x}, @@ -116,18 +115,16 @@ def square_error_cost(input, label): Examples: todo """ helper = MpcLayerHelper('square_error_cost', **locals()) - minus_out = helper.create_mpc_variable_for_type_inference( - dtype=input.dtype) + minus_out = helper.create_mpc_variable_for_type_inference(dtype=input.dtype) helper.append_op( type='mpc_elementwise_sub', inputs={'X': [input], 'Y': [label]}, outputs={'Out': [minus_out]}) - square_out = helper.create_mpc_variable_for_type_inference( - dtype=input.dtype) + square_out = helper.create_mpc_variable_for_type_inference(dtype=input.dtype) helper.append_op( - type='mpc_square', + type='mpc_square', inputs={'X': [minus_out]}, outputs={'Out': [square_out]}) return square_out diff --git a/python/paddle_fl/mpc/layers/matrix.py b/python/paddle_fl/mpc/layers/matrix.py index 3fffcbb2fccf614bffad1658359e0f90ab041252..f88f1f3fffcd5eb990b5396d8197d19e44e80765 100644 --- a/python/paddle_fl/mpc/layers/matrix.py +++ b/python/paddle_fl/mpc/layers/matrix.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,12 +14,14 @@ """ mpc matrix op layers. 
""" -from paddle.fluid.data_feeder import check_type_and_dtype from ..framework import MpcVariable +from ..framework import check_mpc_variable_and_dtype from ..mpc_layer_helper import MpcLayerHelper -__all__ = ['mul', ] +__all__ = [ + 'mul', +] def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): @@ -61,13 +63,13 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): inputs = {"X": [x], "Y": [y]} attrs = { - "x_num_col_dims": x_num_col_dims, + "x_num_col_dims": x_num_col_dims, "y_num_col_dims": y_num_col_dims } helper = MpcLayerHelper("mul", **locals()) - check_type_and_dtype(x, 'x', MpcVariable, ['int64'], 'mul') - check_type_and_dtype(y, 'y', MpcVariable, ['int64'], 'mul') + check_mpc_variable_and_dtype(x, 'x', ['int64'], 'mul') + check_mpc_variable_and_dtype(y, 'y', ['int64'], 'mul') if name is None: out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype) else: @@ -75,9 +77,9 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): name=name, dtype=x.dtype, persistable=False) helper.append_op( - type="mpc_mul", - inputs={"X": x, - "Y": y}, - attrs=attrs, + type="mpc_mul", + inputs={"X": x, + "Y": y}, + attrs=attrs, outputs={"Out": out}) return out diff --git a/python/paddle_fl/mpc/layers/ml.py b/python/paddle_fl/mpc/layers/ml.py index 39cf6b90ac2e5bf8a01ae93fab305eed3d435f6d..3781766453de2b96e8837e2cdcc53470161ae842 100644 --- a/python/paddle_fl/mpc/layers/ml.py +++ b/python/paddle_fl/mpc/layers/ml.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,9 +17,9 @@ mpc ml op layers. from functools import reduce from paddle.fluid.data_feeder import check_type, check_dtype -from paddle.fluid.data_feeder import check_type_and_dtype import numpy from ..framework import MpcVariable +from ..framework import check_mpc_variable_and_dtype from ..mpc_layer_helper import MpcLayerHelper __all__ = [ @@ -30,9 +30,6 @@ __all__ = [ ] -# add softmax, relu - - def fc(input, size, num_flatten_dims=1, @@ -186,8 +183,7 @@ def softmax(input, use_cudnn=False, name=None, axis=-1): """ attrs = {"axis": axis, "use_cudnn": use_cudnn} helper = MpcLayerHelper('softmax', **locals()) - check_type_and_dtype(input, 'input', MpcVariable, - ['float16', 'float32', 'float64'], 'softmax') + check_mpc_variable_and_dtype(input, 'input', ['int64'], 'softmax') dtype = helper.input_dtype() mpc_softmax_out = helper.create_mpc_variable_for_type_inference(dtype) @@ -226,7 +222,9 @@ def relu(input, name=None): dtype = helper.input_dtype(input_param_name='input') out = helper.create_mpc_variable_for_type_inference(dtype) helper.append_op( - type="mpc_relu", inputs={"X": input}, outputs={"Y": out}) + type="mpc_relu", + inputs={"X": input}, + outputs={"Y": out}) return out diff --git a/python/paddle_fl/mpc/layers/mpc_math_op_patch.py b/python/paddle_fl/mpc/layers/mpc_math_op_patch.py index 5edb6a9dccdd16b0b8ee279f493780f3cc2e774e..4db3246757bc6b30c02a838205e8714e696158c8 100644 --- a/python/paddle_fl/mpc/layers/mpc_math_op_patch.py +++ b/python/paddle_fl/mpc/layers/mpc_math_op_patch.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -32,7 +32,6 @@ def monkey_patch_mpc_variable():
     Monkey patch for operator overloading.
     :return:
     """
-
     def unique_tmp_name():
         """
         Generate temp name for variable.
@@ -80,7 +79,9 @@ def monkey_patch_mpc_variable():
         tmp_name = unique_tmp_name()
         return block.create_var(name=tmp_name, dtype=dtype)
-    def _elemwise_method_creator_(method_name, op_type, reverse=False):
+    def _elemwise_method_creator_(method_name,
+                                  op_type,
+                                  reverse=False):
         """
         Operator overloading for different method.
         :param method_name: the name of operator which is overloaded.
@@ -88,19 +89,16 @@ def monkey_patch_mpc_variable():
         :param reverse:
         :return:
         """
-
        def __impl__(self, other_var):
            lhs_dtype = safe_get_dtype(self)
            if method_name in compare_ops:
                if not isinstance(other_var, Variable):
-                    raise NotImplementedError(
-                        "Unsupported data type of {} for compare operations."
-                        .format(other_var.name))
+                    raise NotImplementedError("Unsupported data type of {} for compare operations."
+                                              .format(other_var.name))
            else:
                if not isinstance(other_var, MpcVariable):
-                    raise NotImplementedError(
-                        "Unsupported data type of {}.".format(other_var.name))
+                    raise NotImplementedError("Unsupported data type of {}.".format(other_var.name))
            rhs_dtype = safe_get_dtype(other_var)
            if reverse:
@@ -111,8 +109,7 @@ def monkey_patch_mpc_variable():
            if method_name in compare_ops:
                out = create_new_tmp_var(current_block(self), dtype=rhs_dtype)
            else:
-                out = create_new_tmp_mpc_var(
-                    current_block(self), dtype=lhs_dtype)
+                out = create_new_tmp_mpc_var(current_block(self), dtype=lhs_dtype)
            # out = create_new_tmp_mpc_var(current_block(self), dtype=lhs_dtype)
@@ -120,9 +117,9 @@ def monkey_patch_mpc_variable():
            if other_var.shape[0] == -1:
                axis = 0
            assert len(self.shape) >= len(other_var.shape), (
-                "The rank of the first argument of an binary operator cannot "
-                "be smaller than the rank of its second argument: %s vs %s" %
-                (len(self.shape), len(other_var.shape)))
+                "The rank of the first argument of an binary operator cannot "
+                "be smaller than the rank of its second argument: %s vs %s" %
+                (len(self.shape), len(other_var.shape)))
            current_block(self).append_op(
                type=op_type,
@@ -157,32 +154,33 @@ def monkey_patch_mpc_variable():
     # inject methods
     for method_name, op_type, reverse in (
-            ("__add__", "mpc_elementwise_add", False),
+        ("__add__", "mpc_elementwise_add", False),
        # a+b == b+a. Do not need to reverse explicitly
-            ("__radd__", "mpc_elementwise_add", False),
-            ("__sub__", "mpc_elementwise_sub", False),
-            ("__rsub__", "mpc_elementwise_sub", True),
-            ("__mul__", "mpc_elementwise_mul", False),
+        ("__radd__", "mpc_elementwise_add", False),
+        ("__sub__", "mpc_elementwise_sub", False),
+        ("__rsub__", "mpc_elementwise_sub", True),
+        ("__mul__", "mpc_elementwise_mul", False),
        # a*b == b*a. Do not need to reverse explicitly
-            ("__rmul__", "mpc_elementwise_mul", False),
-            ("__div__", "mpc_elementwise_div", False),
-            ("__truediv__", "mpc_elementwise_div", False),
-            ("__rdiv__", "mpc_elementwise_div", True),
-            ("__rtruediv__", "mpc_elementwise_div", True),
-            ("__pow__", "mpc_elementwise_pow", False),
-            ("__rpow__", "mpc_elementwise_pow", True),
-            ("__floordiv__", "mpc_elementwise_floordiv", False),
-            ("__mod__", "mpc_elementwise_mod", False),
+        ("__rmul__", "mpc_elementwise_mul", False),
+        ("__div__", "mpc_elementwise_div", False),
+        ("__truediv__", "mpc_elementwise_div", False),
+        ("__rdiv__", "mpc_elementwise_div", True),
+        ("__rtruediv__", "mpc_elementwise_div", True),
+        ("__pow__", "mpc_elementwise_pow", False),
+        ("__rpow__", "mpc_elementwise_pow", True),
+        ("__floordiv__", "mpc_elementwise_floordiv", False),
+        ("__mod__", "mpc_elementwise_mod", False),
        # for logical compare
-            ("__eq__", "mpc_equal", False),
-            ("__ne__", "mpc_not_equal", False),
-            ("__lt__", "mpc_less_than", False),
-            ("__le__", "mpc_less_equal", False),
-            ("__gt__", "mpc_greater_than", False),
-            ("__ge__", "mpc_greater_equal", False)):
+        ("__eq__", "mpc_equal", False),
+        ("__ne__", "mpc_not_equal", False),
+        ("__lt__", "mpc_less_than", False),
+        ("__le__", "mpc_less_equal", False),
+        ("__gt__", "mpc_greater_than", False),
+        ("__ge__", "mpc_greater_equal", False)
+        ):
        # Not support computation between MpcVariable and scalar.
-        setattr(MpcVariable, method_name,
+        setattr(MpcVariable,
+                method_name,
                _elemwise_method_creator_(method_name, op_type, reverse)
                if method_name in supported_mpc_ops else announce_not_impl)
-    # MpcVariable.astype = astype
diff --git a/python/setup.py b/python/setup.py
index 65aaa90e9fc92fc7fd40196010a05aab261f6ab7..c03655a4c2445b2c9a7e0f1b65d75d7d07a5bad0 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -34,7 +34,7 @@ def python_version():
 max_version, mid_version, min_version = python_version()
 REQUIRED_PACKAGES = [
-    'six >= 1.10.0', 'protobuf >= 3.1.0', 'paddlepaddle == 1.6.3',
+    'six >= 1.10.0', 'protobuf >= 3.1.0', 'paddlepaddle >= 1.8.0',
     'paddlepaddle-gpu >= 1.8'
 ]
 if max_version < 3: