Unverified commit a1a9bf6b, authored by Qinghe JING, committed by GitHub

Merge pull request #70 from jhjiangcs/smc-611

Improve code to support PaddlePaddle 1.8.0.
...@@ -34,8 +34,8 @@ execute_process(COMMAND ${PYTHON} -c "import paddle;print(paddle.version.full_ve
    RESULT_VARIABLE ret OUTPUT_VARIABLE paddle_version OUTPUT_STRIP_TRAILING_WHITESPACE)
if (NOT ret)
-    if (NOT ${paddle_version} STREQUAL "1.6.3")
-        message(FATAL_ERROR "Paddle installation of 1.6.3 is required but ${paddle_version} is found")
+    if (NOT ${paddle_version} STREQUAL "1.8.0")
+        message(FATAL_ERROR "Paddle installation of 1.8.0 is required but ${paddle_version} is found")
    endif()
else()
    message(FATAL_ERROR "Could not get paddle version.")
......
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

-#include "mpc_compare_op.h"
#include "paddle/fluid/framework/op_registry.h"
+#include "mpc_compare_op.h"

namespace paddle {
namespace operators {

...@@ -23,85 +23,73 @@ using Tensor = framework::Tensor;

class MpcCompareOp : public framework::OperatorWithKernel {
public:
    using framework::OperatorWithKernel::OperatorWithKernel;

    void InferShape(framework::InferShapeContext* ctx) const override {
        PADDLE_ENFORCE_EQ(
            ctx->HasInput("X"), true,
            platform::errors::NotFound("Input(X) of MpcCompareOp should not be null."));
        PADDLE_ENFORCE_EQ(
            ctx->HasInput("Y"), true,
            platform::errors::NotFound("Input(Y) of MpcCompareOp should not be null."));
        PADDLE_ENFORCE_EQ(
            ctx->HasOutput("Out"), true,
            platform::errors::NotFound("Output(Out) of MpcCompareOp should not be null."));

        auto dim_x = ctx->GetInputDim("X");
        auto dim_y = ctx->GetInputDim("Y");
        PADDLE_ENFORCE_GE(dim_x.size(), dim_y.size(),
                          "The size of dim_y should not be greater than dim_x's.");

        ctx->ShareDim("Y", /*->*/ "Out");
        ctx->ShareLoD("Y", /*->*/ "Out");
    }

    framework::OpKernelType GetExpectedKernelType(
        const framework::ExecutionContext& ctx) const override {
        return framework::OpKernelType(
            OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
    }
};

class MpcCompareOpMaker : public framework::OpProtoAndCheckerMaker {
public:
    void Make() override {
        AddInput("X", "(Tensor), The first input tensor of MpcCompareOp.");
        AddInput("Y", "(Tensor), The second input tensor of MpcCompareOp.");
        AddOutput("Out", "(Tensor), The output tensor of MpcCompareOp.");
        AddComment(R"DOC(
MPC Compare Operator.
)DOC");
    }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_WITHOUT_GRADIENT(mpc_greater_than, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_greater_than,
    ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcGreaterThanFunctor>);

REGISTER_OP_WITHOUT_GRADIENT(mpc_greater_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_greater_equal,
    ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcGreaterEqualFunctor>);

REGISTER_OP_WITHOUT_GRADIENT(mpc_less_than, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_less_than,
    ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcLessThanFunctor>);

REGISTER_OP_WITHOUT_GRADIENT(mpc_less_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_less_equal,
    ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcLessEqualFunctor>);

REGISTER_OP_WITHOUT_GRADIENT(mpc_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_equal,
    ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcEqualFunctor>);

REGISTER_OP_WITHOUT_GRADIENT(mpc_not_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_not_equal,
    ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcNotEqualFunctor>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "mpc_op.h"
-#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
-#include <math.h>
-#include <type_traits>

namespace paddle {
namespace operators {

...@@ -24,58 +21,52 @@ namespace operators {

using Tensor = framework::Tensor;

struct MpcGreaterThanFunctor {
    void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
        mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->gt(in_x_t, in_y_t, out_t);
    }
};

struct MpcGreaterEqualFunctor {
    void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
        mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->geq(in_x_t, in_y_t, out_t);
    }
};

struct MpcLessThanFunctor {
    void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
        mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->lt(in_x_t, in_y_t, out_t);
    }
};

struct MpcLessEqualFunctor {
    void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
        mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->leq(in_x_t, in_y_t, out_t);
    }
};

struct MpcEqualFunctor {
    void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
        mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->eq(in_x_t, in_y_t, out_t);
    }
};

struct MpcNotEqualFunctor {
    void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
        mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->neq(in_x_t, in_y_t, out_t);
    }
};

template <typename DeviceContext, typename T, typename Functor>
class MpcCompareOpKernel : public MpcOpKernel<T> {
public:
    void ComputeImpl(const framework::ExecutionContext &ctx) const override {
        auto *in_x_t = ctx.Input<framework::LoDTensor>("X");
        auto *in_y_t = ctx.Input<framework::LoDTensor>("Y");
        auto *out_t = ctx.Output<framework::LoDTensor>("Out");
        auto out = out_t->mutable_data<T>(ctx.GetPlace());
        Functor().Run(in_x_t, in_y_t, out_t);
    }
};

}  // namespace operators
}  // namespace paddl
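All six comparison ops above are registered against the same MpcCompareOpKernel and differ only in the functor supplied as its third template argument, which forwards to the corresponding MPC protocol primitive (gt, geq, lt, leq, eq, neq). The standalone sketch below shows that functor-as-policy dispatch in isolation; GreaterThanFunctor, LessEqualFunctor and CompareKernel are plain-array stand-ins operating on public values, not on secret shares.

// Standalone illustration of the functor-as-policy pattern used by
// MpcCompareOpKernel. Plain vectors replace Tensors and the MPC protocol
// calls; only the dispatch structure mirrors the code above.
#include <cstdint>
#include <cstdio>
#include <vector>

// Each functor plays the role of MpcGreaterThanFunctor / MpcLessEqualFunctor:
// it knows how to produce the elementwise result for one comparison.
struct GreaterThanFunctor {
    void Run(const std::vector<int64_t>& x, const std::vector<int64_t>& y,
             std::vector<int64_t>* out) {
        for (size_t i = 0; i < x.size(); ++i) (*out)[i] = x[i] > y[i];
    }
};

struct LessEqualFunctor {
    void Run(const std::vector<int64_t>& x, const std::vector<int64_t>& y,
             std::vector<int64_t>* out) {
        for (size_t i = 0; i < x.size(); ++i) (*out)[i] = x[i] <= y[i];
    }
};

// One kernel template serves every comparison op, mirroring
// MpcCompareOpKernel<DeviceContext, T, Functor>::ComputeImpl.
template <typename Functor>
void CompareKernel(const std::vector<int64_t>& x, const std::vector<int64_t>& y,
                   std::vector<int64_t>* out) {
    Functor().Run(x, y, out);
}

int main() {
    std::vector<int64_t> x = {1, 5, 3}, y = {2, 5, 1}, out(3);
    CompareKernel<GreaterThanFunctor>(x, y, &out);  // expected: 0 0 1
    printf("%lld %lld %lld\n", (long long)out[0], (long long)out[1], (long long)out[2]);
    CompareKernel<LessEqualFunctor>(x, y, &out);    // expected: 1 1 0
    printf("%lld %lld %lld\n", (long long)out[0], (long long)out[1], (long long)out[2]);
    return 0;
}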
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

-#include "mpc_elementwise_add_op.h"
#include "paddle/fluid/framework/op_registry.h"
+#include "mpc_elementwise_add_op.h"

namespace paddle {
namespace operators {

...@@ -22,111 +22,105 @@ using Tensor = framework::Tensor;

class MpcElementwiseAddOp : public framework::OperatorWithKernel {
public:
    using framework::OperatorWithKernel::OperatorWithKernel;

    void InferShape(framework::InferShapeContext* ctx) const override {
        PADDLE_ENFORCE_EQ(
            ctx->HasInput("X"), true,
            platform::errors::NotFound("Input(X) of MpcElementwiseAddOp should not be null."));
        PADDLE_ENFORCE_EQ(
            ctx->HasInput("Y"), true,
            platform::errors::NotFound("Input(Y) of MpcElementwiseAddOp should not be null."));
        PADDLE_ENFORCE_EQ(
            ctx->HasOutput("Out"), true,
            platform::errors::NotFound("Output(Out) of MpcElementwiseAddOp should not be null."));
        PADDLE_ENFORCE_GE(
            ctx->GetInputDim("X").size(), ctx->GetInputDim("Y").size(),
            platform::errors::InvalidArgument(
                "The dimensions of X should be greater than the dimensions of Y. "
                "But received the dimensions of X is [%s], the dimensions of Y is [%s]",
                ctx->GetInputDim("X"), ctx->GetInputDim("Y")));

        ctx->ShareDim("X", /*->*/ "Out");
        ctx->ShareLoD("X", /*->*/ "Out");
    }
};

class MpcElementwiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
    void Make() override {
        AddInput("X", "(Tensor), The first input tensor of mpc elementwise add op.");
        AddInput("Y", "(Tensor), The second input tensor of mpc elementwise add op.");
        AddOutput("Out", "(Tensor), The output tensor of mpc elementwise add op.");
        AddAttr<int>("axis",
                     "(int, default -1). If X.dimension != Y.dimension,"
                     "Y.dimension must be a subsequence of x.dimension. And axis "
                     "is the start dimension index "
                     "for broadcasting Y onto X. ")
            .SetDefault(-1)
            .EqualGreaterThan(-1);
        AddComment(R"DOC(
MPC elementwise add Operator.
)DOC");
    }
};

class MpcElementwiseAddGradOp : public framework::OperatorWithKernel {
public:
    using framework::OperatorWithKernel::OperatorWithKernel;
    using Tensor = framework::Tensor;

    void InferShape(framework::InferShapeContext *ctx) const override {
        auto out_grad_name = framework::GradVarName("Out");
        PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null.");
        PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null.");
        PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true,
                          "Input(Out@GRAD) should not be null.");

        auto x_grad_name = framework::GradVarName("X");
        auto y_grad_name = framework::GradVarName("Y");
        if (ctx->HasOutput(x_grad_name)) {
            ctx->ShareDim("X", /*->*/ x_grad_name);
            ctx->ShareLoD("X", /*->*/ x_grad_name);
        }
        if (ctx->HasOutput(y_grad_name)) {
            ctx->ShareDim("Y", /*->*/ y_grad_name);
            ctx->ShareLoD("Y", /*->*/ y_grad_name);
        }
    }
};

template <typename T>
-class MpcElementwiseAddOpGradMaker : public framework::SingleGradOpDescMaker {
+class MpcElementwiseAddOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
-    using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+    using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
-    std::unique_ptr<T> Apply() const override {
-        std::unique_ptr<T> retv(new T());
-        retv->SetType("mpc_elementwise_add_grad");
-        retv->SetInput("X", this->Input("X"));
-        retv->SetInput("Y", this->Input("Y"));
-        retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-        retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-        retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
-        retv->SetAttrMap(this->Attrs());
-        return retv;
-    }
+    void Apply(GradOpPtr<T> grad) const override {
+        grad->SetType("mpc_elementwise_add_grad");
+        grad->SetInput("X", this->Input("X"));
+        grad->SetInput("Y", this->Input("Y"));
+        grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+        grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+        grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
+        grad->SetAttrMap(this->Attrs());
+    }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(mpc_elementwise_add, ops::MpcElementwiseAddOp,
                  ops::MpcElementwiseAddOpMaker,
                  ops::MpcElementwiseAddOpGradMaker<paddle::framework::OpDesc>);

REGISTER_OPERATOR(mpc_elementwise_add_grad, ops::MpcElementwiseAddGradOp);

REGISTER_OP_CPU_KERNEL(
    mpc_elementwise_add,
    ops::MpcElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>);

REGISTER_OP_CPU_KERNEL(
    mpc_elementwise_add_grad,
    ops::MpcElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
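The substantive change in this file (and in mpc_elementwise_sub_op.cc further down) is the gradient-maker migration: framework::SingleGradOpDescMaker with an Apply() that allocates and returns a std::unique_ptr is replaced by framework::SingleGradOpMaker<T> with a void Apply(GradOpPtr<T>) that only fills in a grad op handed over by the framework. The standalone sketch below mirrors just that before/after shape; FakeOpDesc, GradOpPtr, OldStyleMaker and NewStyleMaker are illustrative stand-ins, not Paddle types.

// Illustrative sketch of the grad-maker API change applied in this commit.
// Only the shape of Apply() matches the diff above; all types are stand-ins.
#include <iostream>
#include <memory>
#include <string>

struct FakeOpDesc {                         // stands in for framework::OpDesc
    std::string type;
    void SetType(const std::string& t) { type = t; }
};

template <typename T>
using GradOpPtr = std::shared_ptr<T>;       // stand-in for the framework-provided handle

// Old shape: the maker allocates the grad op desc and returns it.
struct OldStyleMaker {
    std::unique_ptr<FakeOpDesc> Apply() const {
        std::unique_ptr<FakeOpDesc> retv(new FakeOpDesc());
        retv->SetType("mpc_elementwise_add_grad");
        return retv;
    }
};

// New shape: the caller allocates, the maker only fills it in.
struct NewStyleMaker {
    void Apply(GradOpPtr<FakeOpDesc> grad) const {
        grad->SetType("mpc_elementwise_add_grad");
    }
};

int main() {
    auto old_desc = OldStyleMaker().Apply();
    auto new_desc = std::make_shared<FakeOpDesc>();
    NewStyleMaker().Apply(new_desc);
    std::cout << old_desc->type << " / " << new_desc->type << "\n";
    return 0;
}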
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// This op is different with elementwise_add of PaddlePaddle.
// We only consider that the dimensions of X is equal with the dimensions of Y.

...@@ -18,7 +18,6 @@

#pragma once

#include "mpc_op.h"
#include "paddle/fluid/platform/transform.h"
-#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"

namespace paddle {
namespace operators {

...@@ -26,189 +25,187 @@ namespace operators {

using Tensor = framework::Tensor;

// paddle/fluid/operators/elementwise/elementwise_op_function.h
template <typename T, typename DeviceContext>
class RowwiseTransformIterator;

template <typename T>
class RowwiseTransformIterator<T, platform::CPUDeviceContext>
    : public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t, T *, T &> {
public:
    RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {}

    RowwiseTransformIterator<T, platform::CPUDeviceContext> &operator++() {
        ++i_;
        if (UNLIKELY(i_ == n_)) {
            i_ = 0;
        }
        return *this;
    }

    RowwiseTransformIterator<T, platform::CPUDeviceContext> &operator+(int n) {
        while (n-- > 0) {
            ++i_;
            if (UNLIKELY(i_ == n_)) {
                i_ = 0;
            }
        }
        return *this;
    }

    bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext> &rhs) const {
        return (ptr_ + i_) == &(*rhs);
    }

    bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext> &rhs) const {
        return (ptr_ + i_) != &(*rhs);
    }

    const T &operator*() { return ptr_[i_]; }

private:
    const T *ptr_;
    int i_;
    int64_t n_;
};

template <typename T>
struct AddFunctor {
    inline HOSTDEVICE T operator()(T x, T y) { return x + y; }
};

struct GetMidDims {
    inline HOSTDEVICE void operator()(const framework::DDim &x_dims,
                                      const framework::DDim &y_dims, const int axis,
                                      int *pre, int *n, int *post) {
        *pre = 1;
        *n = 1;
        *post = 1;
        for (int i = 1; i < axis + 1; ++i) {
            (*pre) *= x_dims[i];
        }
        for (int i = 1; i < y_dims.size(); ++i) {
            PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
                              "Broadcast dimension mismatch.");
            (*n) *= y_dims[i];
        }
        for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
            (*post) *= x_dims[i];
        }
    }
};

const size_t SHARE_NUM = 2;

template <typename DeviceContext, typename T>
class MpcElementwiseAddKernel : public MpcOpKernel<T> {
public:
    void ComputeImpl(const framework::ExecutionContext &ctx) const override {
-        auto *in_x_t = ctx.Input<Tensor>("X");
-        auto *in_y_t = ctx.Input<Tensor>("Y");
-        auto *out_t = ctx.Output<Tensor>("Out");
+        auto *in_x_t = ctx.Input<framework::LoDTensor>("X");
+        auto *in_y_t = ctx.Input<framework::LoDTensor>("Y");
+        auto *out_t = ctx.Output<framework::LoDTensor>("Out");
        int axis = ctx.Attr<int>("axis");
        auto out = out_t->mutable_data<T>(ctx.GetPlace());

        if (in_x_t->dims() == in_y_t->dims()) {
            mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(in_x_t, in_y_t, out_t);
        } else {
            Tensor in_x_t_slice;
            Tensor in_y_t_slice;
            Tensor out_t_slice;

            for (size_t i = 0; i < SHARE_NUM; ++i) {
                in_x_t_slice = in_x_t->Slice(i, i + 1);
                in_y_t_slice = in_y_t->Slice(i, i + 1);
                out_t_slice = out_t->Slice(i, i + 1);

                auto x_dims = in_x_t_slice.dims();
                auto y_dims = in_y_t_slice.dims();

                axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
                PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                               "Axis should be in range [0, x_dims)");

                int pre, n, post;
                GetMidDims get_mid_dims;
                get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
                PADDLE_ENFORCE_EQ(post, 1,
                                  "post should be equal 1, but received post is [%s]", post);

                auto x_ = in_x_t_slice.data<T>();
                auto y_ = in_y_t_slice.data<T>();
                auto out_ = out_t_slice.data<T>();
                auto nx_ = in_x_t_slice.numel();
                paddle::platform::Transform<DeviceContext> trans;
                trans(ctx.template device_context<DeviceContext>(), x_, x_ + nx_,
                      RowwiseTransformIterator<T, DeviceContext>(y_, n),
                      out_, AddFunctor<T>());
            }
        }
    }
};

template <typename DeviceContext, typename T>
class MpcElementwiseAddGradKernel : public MpcOpKernel<T> {
public:
    void ComputeImpl(const framework::ExecutionContext &ctx) const override {
        auto *in_x_t = ctx.Input<framework::LoDTensor>("X");
        auto *in_y_t = ctx.Input<framework::LoDTensor>("Y");
-        auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-        auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-        auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+        auto *dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+        auto *dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+        auto *dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
        int axis = ctx.Attr<int>("axis");
        auto dout_data = dout->data<T>();

        if (dx) {
            auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
            for (size_t i = 0; i < dout->numel(); i++) {
                dx_data[i] = dout_data[i];
            }
        }

        if (dy) {
            auto dy_data = dy->mutable_data<T>(ctx.GetPlace());
            if (in_x_t->dims().size() == in_y_t->dims().size()) {
                for (size_t i = 0; i < dout->numel(); i++) {
                    dy_data[i] = dout_data[i];
                }
            } else {
                auto x_dims = in_x_t->dims();
                auto y_dims = in_y_t->dims();

                axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
                PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                               "Axis should be in range [0, x_dims)");

                int pre, n, post;
                GetMidDims get_mid_dims;
                get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
                PADDLE_ENFORCE_EQ(post, 1,
                                  "post should be equal 1, but received post is [%s]", post);

                for (size_t i = 0; i < SHARE_NUM; ++i) {
                    int y_offset = i * n;
                    for (size_t j = 0; j < pre; ++j) {
                        for (size_t k = 0; k < n; ++k) {
                            int out_offset = i * pre * n + j * n + k;
                            if (0 == j) {
                                dy_data[k + y_offset] = dout_data[out_offset];
                            } else {
                                dy_data[k + y_offset] += dout_data[out_offset];
                            }
                        }
                    }
                }
            }
        }
    }
};

}  // namespace operators
}  // namespace paddle
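When the input dims differ, the forward kernel above slices each of the SHARE_NUM shares, computes pre/n/post with GetMidDims, and adds a length-n Y to every row of X through RowwiseTransformIterator, with post enforced to be 1. The plain-array sketch below illustrates the same broadcast arithmetic outside of Paddle; RowwiseBroadcastAdd and its arguments are illustrative only.

// Plain-array sketch of the row-wise broadcast used by MpcElementwiseAddKernel
// when dims differ: x is viewed as [pre, n] and a length-n y is added to every
// row, i.e. out[j * n + k] = x[j * n + k] + y[k].
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<int64_t> RowwiseBroadcastAdd(const std::vector<int64_t>& x,
                                         const std::vector<int64_t>& y,
                                         int pre, int n) {
    std::vector<int64_t> out(x.size());
    for (int j = 0; j < pre; ++j) {
        for (int k = 0; k < n; ++k) {
            out[j * n + k] = x[j * n + k] + y[k];  // y index wraps once per row
        }
    }
    return out;
}

int main() {
    // x has shape [pre=2, n=3] flattened; y has shape [n=3].
    std::vector<int64_t> x = {1, 2, 3, 4, 5, 6};
    std::vector<int64_t> y = {10, 20, 30};
    auto out = RowwiseBroadcastAdd(x, y, 2, 3);
    for (auto v : out) printf("%lld ", (long long)v);  // 11 22 33 14 25 36
    printf("\n");
    return 0;
}

The gradient kernel inverts this broadcast by accumulating dout over the pre dimension, i.e. dy[k] = sum over j of dout[j * n + k], computed separately for each share.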
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

-#include "mpc_elementwise_sub_op.h"
#include "paddle/fluid/framework/op_registry.h"
+#include "mpc_elementwise_sub_op.h"

namespace paddle {
namespace operators {

class MpcElementwiseSubOp : public framework::OperatorWithKernel {
public:
    using framework::OperatorWithKernel::OperatorWithKernel;

    void InferShape(framework::InferShapeContext* ctx) const override {
        PADDLE_ENFORCE_EQ(
            ctx->HasInput("X"), true,
            platform::errors::NotFound("Input(X) of MpcElementwiseSubOp should not be null."));
        PADDLE_ENFORCE_EQ(
            ctx->HasInput("Y"), true,
            platform::errors::NotFound("Input(Y) of MpcElementwiseSubOp should not be null."));
        PADDLE_ENFORCE_EQ(
            ctx->HasOutput("Out"), true,
            platform::errors::NotFound("Output(Out) of MpcElementwiseSubOp should not be null."));
        PADDLE_ENFORCE_EQ(
            ctx->GetInputDim("X"), ctx->GetInputDim("Y"),
            platform::errors::InvalidArgument(
                "The dimensions of X should be equal with the dimensions of Y. "
                "But received the dimensions of X is [%s], the dimensions of Y is [%s]",
                ctx->GetInputDim("X"), ctx->GetInputDim("Y")));

        ctx->ShareDim("X", /*->*/ "Out");
        ctx->ShareLoD("X", /*->*/ "Out");
    }
};

class MpcElementwiseSubOpMaker : public framework::OpProtoAndCheckerMaker {
public:
    void Make() override {
        AddInput("X", "(Tensor), The first input tensor of mpc elementwise sub op.");
        AddInput("Y", "(Tensor), The second input tensor of mpc elementwise sub op.");
        AddOutput("Out", "(Tensor), The output tensor of mpc elementwise sub op.");
        AddComment(R"DOC(
MPC elementwise sub Operator.
)DOC");
    }
};

class MpcElementwiseSubGradOp : public framework::OperatorWithKernel {
public:
    using framework::OperatorWithKernel::OperatorWithKernel;

    void InferShape(framework::InferShapeContext *ctx) const override {
        auto out_grad_name = framework::GradVarName("Out");
        PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null.");
        PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null.");
        PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true,
                          "Input(Out@GRAD) should not be null.");

        auto x_grad_name = framework::GradVarName("X");
        auto y_grad_name = framework::GradVarName("Y");
        if (ctx->HasOutput(x_grad_name)) {
            ctx->ShareDim("X", /*->*/ x_grad_name);
            ctx->ShareLoD("X", /*->*/ x_grad_name);
        }
        if (ctx->HasOutput(y_grad_name)) {
            ctx->ShareDim("Y", /*->*/ y_grad_name);
            ctx->ShareLoD("Y", /*->*/ y_grad_name);
        }
    }
};

template <typename T>
-class MpcElementwiseSubGradMaker : public framework::SingleGradOpDescMaker {
+class MpcElementwiseSubGradMaker : public framework::SingleGradOpMaker<T> {
public:
-    using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+    using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
-    std::unique_ptr<T> Apply() const override {
-        std::unique_ptr<T> retv(new T());
-        retv->SetType("mpc_elementwise_sub_grad");
-        retv->SetInput("X", this->Input("X"));
-        retv->SetInput("Y", this->Input("Y"));
-        retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-        retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-        retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
-        retv->SetAttrMap(this->Attrs());
-        return retv;
-    }
+    void Apply(GradOpPtr<T> grad) const override {
+        grad->SetType("mpc_elementwise_sub_grad");
+        grad->SetInput("X", this->Input("X"));
+        grad->SetInput("Y", this->Input("Y"));
+        grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+        grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+        grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
+        grad->SetAttrMap(this->Attrs());
+    }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(mpc_elementwise_sub, ops::MpcElementwiseSubOp,
                  ops::MpcElementwiseSubOpMaker,
                  ops::MpcElementwiseSubGradMaker<paddle::framework::OpDesc>);

REGISTER_OPERATOR(mpc_elementwise_sub_grad, ops::MpcElementwiseSubGradOp);

REGISTER_OP_CPU_KERNEL(
    mpc_elementwise_sub,
    ops::MpcElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>);

REGISTER_OP_CPU_KERNEL(
    mpc_elementwise_sub_grad,
    ops::MpcElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// This op is different with elementwise_sub of PaddlePaddle.
// We only consider that the dimensions of X is equal with the dimensions of Y.

#pragma once

#include "mpc_op.h"
-#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"

namespace paddle {
namespace operators {

...@@ -27,40 +26,39 @@ using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class MpcElementwiseSubKernel : public MpcOpKernel<T> {
public:
    void ComputeImpl(const framework::ExecutionContext &ctx) const override {
        auto *in_x_t = ctx.Input<Tensor>("X");
        auto *in_y_t = ctx.Input<Tensor>("Y");
        auto *out_t = ctx.Output<Tensor>("Out");
        auto out = out_t->mutable_data<T>(ctx.GetPlace());
        mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(in_x_t, in_y_t, out_t);
    }
};

template <typename DeviceContext, typename T>
class MpcElementwiseSubGradKernel : public MpcOpKernel<T> {
public:
    void ComputeImpl(const framework::ExecutionContext &ctx) const override {
        VLOG(3) << "******** MpcElementwiseSubGradKernel: ";
        auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
        auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
        auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
        auto dout_data = dout->data<T>();

        if (dx) {
            auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
            for (size_t i = 0; i < dout->numel(); i++) {
                dx_data[i] = dout_data[i];
            }
        }
        if (dy) {
            dy->mutable_data<T>(ctx.GetPlace());
            mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->neg(dout, dy);
        }
    }
};

}  // namespace operators
}  // namespace paddle
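For Out = X - Y the local derivatives are dOut/dX = 1 and dOut/dY = -1, so the backward pass reduces to dX = dOut and dY = -dOut. That is exactly what the gradient kernel above does: dx is filled by copying dout element by element, while dy is produced by running the protocol's neg() on the dout shares.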
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// Description:

#include "paddle/fluid/framework/op_registry.h"
-#include "core/paddlefl_mpc/mpc_protocol/mpc_config.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
+#include "core/paddlefl_mpc/mpc_protocol/mpc_config.h"

namespace paddle {
namespace operators {

...@@ -26,59 +26,63 @@ using mpc::Aby3Config;

class MpcInitOp : public framework::OperatorBase {
public:
    MpcInitOp(const std::string& type,
              const framework::VariableNameMap& inputs,
              const framework::VariableNameMap& outputs,
              const framework::AttributeMap& attrs)
        : OperatorBase(type, inputs, outputs, attrs) {}

    void RunImpl(const framework::Scope &scope,
                 const platform::Place &dev_place) const override {
        auto protocol_name = Attr<std::string>("protocol_name");
        auto role = Attr<int>("role");
        auto local_addr = Attr<std::string>("local_addr");
        auto net_server_addr = Attr<std::string>("net_server_addr");
        auto net_server_port = Attr<int>("net_server_port");

        MpcConfig _mpc_config;
        _mpc_config.set_int(Aby3Config::ROLE, role);
        _mpc_config.set(Aby3Config::LOCAL_ADDR, local_addr);
        _mpc_config.set(Aby3Config::NET_SERVER_ADDR, net_server_addr);
        _mpc_config.set_int(Aby3Config::NET_SERVER_PORT, net_server_port);

        mpc::MpcInstance::init_instance(protocol_name, _mpc_config);
    }
};

class MpcInitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
    void Make() override {
        AddComment(R"DOC(
Where2 Operator.
)DOC");
        AddAttr<std::string>("protocol_name",
                             "(string , default aby3)"
                             "protocol name")
            .SetDefault({"aby3"});
        AddAttr<int>("role", "trainer role.").SetDefault(0);
        AddAttr<std::string>("local_addr",
                             "(string, default localhost)"
                             "local addr")
            .SetDefault({"localhost"});
        AddAttr<std::string>("net_server_addr",
                             "(string, default localhost)"
                             "net server addr")
            .SetDefault({"localhost"});
        AddAttr<int>("net_server_port", "net server port, default to 6539.").SetDefault(6539);
    }
};

class MpcInitOpShapeInference : public framework::InferShapeBase {
public:
    void operator()(framework::InferShapeContext* ctx) const override {}
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(mpc_init, ops::MpcInitOp, ops::MpcInitOpMaker,
                  ops::MpcInitOpShapeInference);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

-#include "mpc_mean_op.h"
#include "paddle/fluid/framework/op_registry.h"
+#include "mpc_mean_op.h"

namespace paddle {
namespace operators {

...@@ -22,78 +22,80 @@ using Tensor = framework::Tensor;

class MpcMeanOp : public framework::OperatorWithKernel {
public:
    using framework::OperatorWithKernel::OperatorWithKernel;

    void InferShape(framework::InferShapeContext* ctx) const override {
        PADDLE_ENFORCE_EQ(
            ctx->HasInput("X"), true,
            platform::errors::NotFound("Input(X) of MpcMeanOp should not be null."));
        PADDLE_ENFORCE_EQ(
            ctx->HasOutput("Out"), true,
            platform::errors::NotFound("Output(Out) of MpcMeanOp should not be null."));
        ctx->SetOutputDim("Out", {2, 1});
    }
};

class MpcMeanOpMaker : public framework::OpProtoAndCheckerMaker {
public:
    void Make() override {
        AddInput("X", "(Tensor), The first input tensor of mpc mean op.");
        AddOutput("Out", "(Tensor), The output tensor of mpc mean op.");
        AddComment(R"DOC(
MPC mean Operator calculates the mean of all elements in X.
)DOC");
    }
};

class MpcMeanOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
-    std::unordered_map<std::string, std::string> GetInputOutputWithSameType() const override {
-        return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
-    }
+    std::unordered_map<std::string, std::string>& GetInputOutputWithSameType() const override {
+        static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
+        return m;
+    }
};

class MpcMeanGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X"));
}
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X"));
}
}; };
template <typename T> template <typename T>
class MpcMeanOpGradMaker : public framework::SingleGradOpDescMaker { class MpcMeanOpGradMaker : public framework::SingleGradOpMaker<T> {
public: public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> grad) const override {
std::unique_ptr<T> retv(new T()); grad->SetType("mpc_mean_grad");
retv->SetType("mpc_mean_grad"); grad->SetInput("X", this->Input("X"));
retv->SetInput("X", this->Input("X")); grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); }
return retv;
}
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_mean, ops::MpcMeanOp, ops::MpcMeanOpMaker, REGISTER_OPERATOR(mpc_mean, ops::MpcMeanOp,
ops::MpcMeanOpInferVarType, ops::MpcMeanOpMaker,
ops::MpcMeanOpGradMaker<paddle::framework::OpDesc>); ops::MpcMeanOpInferVarType,
ops::MpcMeanOpGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_mean_grad, ops::MpcMeanGradOp); REGISTER_OPERATOR(mpc_mean_grad, ops::MpcMeanGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
mpc_mean, ops::MpcMeanKernel<paddle::platform::CPUDeviceContext, int64_t>); mpc_mean,
ops::MpcMeanKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
mpc_mean_grad, mpc_mean_grad,
ops::MpcMeanGradKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::MpcMeanGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
#pragma once #pragma once
#include "mpc_op.h" #include "mpc_op.h"
#include "paddle/fluid/framework/eigen.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -28,43 +27,40 @@ using EigenVector = framework::EigenVector<T, MajorType, IndexType>; ...@@ -28,43 +27,40 @@ using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MpcMeanKernel : public MpcOpKernel<T> { class MpcMeanKernel : public MpcOpKernel<T> {
public: public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override { void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *in_x_t = ctx.Input<Tensor>("X"); auto *in_x_t = ctx.Input<Tensor>("X");
auto *out_t = ctx.Output<Tensor>("Out"); auto *out_t = ctx.Output<Tensor>("Out");
out_t->mutable_data<T>(ctx.GetPlace()); out_t->mutable_data<T>(ctx.GetPlace());
double scale = 1.0 / (in_x_t->numel() / 2.0); double scale = 1.0 / (in_x_t->numel() / 2.0);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sum( mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sum(in_x_t, out_t);
in_x_t, out_t); mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(out_t, scale, out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale( }
out_t, scale, out_t);
}
}; };
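A note on the scale factor above: an MPC tensor stores its two ABY3 shares along dim 0, so numel() counts every plaintext element twice and the mean divides the share-wise sum by numel()/2, not by numel(). A standalone sketch of that arithmetic with plain integers (ignoring the protocol's fixed-point encoding):

#include <cstdio>
#include <vector>

int main() {
    // Hypothetical share tensor of shape [2, 3]: numel = 6, n = 3 plaintext values.
    std::vector<long long> shares = {1, 2, 3, 4, 5, 6};
    double scale = 1.0 / (shares.size() / 2.0);                // 1/3, as in MpcMeanKernel
    long long sum_share0 = shares[0] + shares[1] + shares[2];
    long long sum_share1 = shares[3] + shares[4] + shares[5];
    std::printf("share0 mean = %.2f, share1 mean = %.2f\n",
                sum_share0 * scale, sum_share1 * scale);       // 2.00 and 5.00
    return 0;
}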
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MpcMeanGradKernel : public MpcOpKernel<T> { class MpcMeanGradKernel : public MpcOpKernel<T> {
public: public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override { void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto dout = ctx.Input<Tensor>(framework::GradVarName("Out")); auto dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE(dout->numel() == 2, PADDLE_ENFORCE(dout->numel() == 2, "numel of MpcMean Gradient should be 2.");
"numel of MpcMean Gradient should be 2."); auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto dout_data = dout->data<T>();
auto dout_data = dout->data<T>();
if (dx) {
auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
for (size_t i = 0; i < dx->numel() / 2; ++i) {
dx_data[i] = dout_data[0];
}
for (size_t i = dx->numel() / 2; i < dx->numel(); ++i) {
dx_data[i] = dout_data[1];
}
double scale_factor = 1.0 / (dx->numel() / 2); if (dx) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale( auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
dx, scale_factor, dx); for (size_t i = 0; i < dx->numel() / 2; ++i) {
dx_data[i] = dout_data[0];
}
for (size_t i = dx->numel() / 2; i < dx->numel(); ++i) {
dx_data[i] = dout_data[1];
}
double scale_factor = 1.0 / (dx->numel() / 2);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(dx, scale_factor, dx);
}
} }
}
}; };
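The backward pass mirrors this: d(mean)/d(x_i) = 1/n for every element, so each d(Out) share is broadcast over its half of d(X) and the result is scaled by 1/n. A standalone sketch with hypothetical numbers:

#include <cstdio>
#include <vector>

int main() {
    const int n = 3;                       // plaintext elements per share
    long long dout[2] = {6, 12};           // hypothetical d(Out): one value per share
    std::vector<double> dx(2 * n);
    for (int i = 0; i < n; ++i)     dx[i] = dout[0];           // first half <- share 0
    for (int i = n; i < 2 * n; ++i) dx[i] = dout[1];           // second half <- share 1
    double scale = 1.0 / n;                                    // d(mean)/d(x_i) = 1/n
    for (auto &v : dx) v *= scale;
    std::printf("dx[0] = %.2f, dx[n] = %.2f\n", dx[0], dx[n]); // 2.00 and 4.00
    return 0;
}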
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
#include "mpc_mul_op.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "mpc_mul_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -22,98 +22,98 @@ using Tensor = framework::Tensor; ...@@ -22,98 +22,98 @@ using Tensor = framework::Tensor;
class MpcMulOp : public framework::OperatorWithKernel { class MpcMulOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, PADDLE_ENFORCE_EQ(
platform::errors::NotFound( ctx->HasInput("X"), true,
"Input(X) of Mpc MulOp should not be null.")); platform::errors::NotFound("Input(X) of Mpc MulOp should not be null."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true, ctx->HasInput("Y"), true,
platform::errors::NotFound("Input(Y) of MpcMulOp should not be null.")); platform::errors::NotFound("Input(Y) of MpcMulOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, PADDLE_ENFORCE_EQ(
platform::errors::NotFound( ctx->HasOutput("Out"), true,
"Output(Out) of MpcMulOp should not be null.")); platform::errors::NotFound("Output(Out) of MpcMulOp should not be null."));
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims"); int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims");
int y_num_col_dims = ctx->Attrs().Get<int>("y_num_col_dims"); int y_num_col_dims = ctx->Attrs().Get<int>("y_num_col_dims");
VLOG(3) << "mpc mul operator x.shape=" << x_dims << " y.shape=" << y_dims VLOG(3) << "mpc mul operator x.shape=" << x_dims << " y.shape=" << y_dims
<< " x_num_col_dims=" << x_num_col_dims << " x_num_col_dims=" << x_num_col_dims
<< " y_num_col_dims=" << y_num_col_dims; << " y_num_col_dims=" << y_num_col_dims;
PADDLE_ENFORCE_NE(framework::product(y_dims), 0, PADDLE_ENFORCE_NE(framework::product(y_dims), 0,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"The Input variable Y(%s) has not " "The Input variable Y(%s) has not "
"been initialized. You may need to confirm " "been initialized. You may need to confirm "
"if you put exe.run(startup_program) " "if you put exe.run(startup_program) "
"after optimizer.minimize function.", "after optimizer.minimize function.",
ctx->Inputs("Y").front())); ctx->Inputs("Y").front()));
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
x_dims.size(), x_num_col_dims, x_dims.size(), x_num_col_dims,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The input tensor X's dimensions of MpcMulOp " "The input tensor X's dimensions of MpcMulOp "
"should be larger than x_num_col_dims. But received X's " "should be larger than x_num_col_dims. But received X's "
"dimensions = %d, X's shape = [%s], x_num_col_dims = %d.", "dimensions = %d, X's shape = [%s], x_num_col_dims = %d.",
x_dims.size(), x_dims, x_num_col_dims)); x_dims.size(), x_dims, x_num_col_dims));
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
y_dims.size(), y_num_col_dims, y_dims.size(), y_num_col_dims,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The input tensor Y's dimensions of MpcMulOp " "The input tensor Y's dimensions of MpcMulOp "
"should be larger than y_num_col_dims. But received Y's " "should be larger than y_num_col_dims. But received Y's "
"dimensions = %d, Y's shape = [%s], y_num_col_dims = %d.", "dimensions = %d, Y's shape = [%s], y_num_col_dims = %d.",
y_dims.size(), y_dims, y_num_col_dims)); y_dims.size(), y_dims, y_num_col_dims));
int x_mat_width = 1; int x_mat_width = 1;
int y_mat_height = 1; int y_mat_height = 1;
for (size_t i = x_num_col_dims + 1; i < x_dims.size(); i++) { for (size_t i = x_num_col_dims + 1; i < x_dims.size(); i++) {
x_mat_width *= x_dims[i]; x_mat_width *= x_dims[i];
} }
for (size_t i = 1; i <= y_num_col_dims; i++) { for (size_t i = 1; i <= y_num_col_dims; i++) {
y_mat_height *= y_dims[i]; y_mat_height *= y_dims[i];
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_mat_width, y_mat_height, x_mat_width, y_mat_height,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"After flatten the input tensor X and Y to 2-D dimensions " "After flatten the input tensor X and Y to 2-D dimensions "
"matrix X1 and Y1, the matrix X1's width must be equal with matrix " "matrix X1 and Y1, the matrix X1's width must be equal with matrix "
"Y1's height. But received X's shape = [%s], X1's " "Y1's height. But received X's shape = [%s], X1's "
"width = %s; Y's shape = [%s], Y1's height = %s.", "width = %s; Y's shape = [%s], Y1's height = %s.",
x_dims, x_mat_width, y_dims, y_mat_height)); x_dims, x_mat_width, y_dims, y_mat_height));
std::vector<int64_t> output_dims; std::vector<int64_t> output_dims;
output_dims.reserve(static_cast<size_t>(1 + x_num_col_dims + y_dims.size() - output_dims.reserve(
y_num_col_dims)); static_cast<size_t>(1 + x_num_col_dims + y_dims.size() - y_num_col_dims));
for (int i = 0; i <= x_num_col_dims; ++i) { // i=0, batch_size (share id) for (int i = 0; i <= x_num_col_dims; ++i) { // i=0, batch_size (share id)
output_dims.push_back(x_dims[i]); output_dims.push_back(x_dims[i]);
}
for (int i = y_num_col_dims + 1; i < y_dims.size(); ++i) {
output_dims.push_back(y_dims[i]);
}
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
ctx->ShareLoD("X", /*->*/ "Out");
} }
for (int i = y_num_col_dims + 1; i < y_dims.size(); ++i) {
output_dims.push_back(y_dims[i]);
}
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
ctx->ShareLoD("X", /*->*/ "Out");
}
}; };
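To make the shape rule above concrete: dim 0 is the share dim, the output keeps X's dims up to and including x_num_col_dims, appends Y's dims after y_num_col_dims, and the flattened inner widths must agree. A standalone sketch with hypothetical shapes:

#include <cstdio>
#include <vector>

int main() {
    std::vector<int> x_dims = {2, 3, 4};   // hypothetical share tensor X, dim 0 = shares
    std::vector<int> y_dims = {2, 4, 5};   // hypothetical share tensor Y
    int x_num_col_dims = 1, y_num_col_dims = 1;

    int x_mat_width = 1, y_mat_height = 1;
    for (size_t i = x_num_col_dims + 1; i < x_dims.size(); ++i) x_mat_width *= x_dims[i];
    for (int i = 1; i <= y_num_col_dims; ++i) y_mat_height *= y_dims[i];
    std::printf("inner dims: %d vs %d (must match)\n", x_mat_width, y_mat_height);  // 4 vs 4

    std::vector<int> out_dims;
    for (int i = 0; i <= x_num_col_dims; ++i) out_dims.push_back(x_dims[i]);
    for (size_t i = y_num_col_dims + 1; i < y_dims.size(); ++i) out_dims.push_back(y_dims[i]);
    for (int d : out_dims) std::printf("%d ", d);              // prints: 2 3 5
    std::printf("\n");
    return 0;
}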
class MpcMulOpMaker : public framework::OpProtoAndCheckerMaker { class MpcMulOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "(Tensor), The first input tensor of mpc mul op."); AddInput("X", "(Tensor), The first input tensor of mpc mul op.");
AddInput("Y", "(Tensor), The second input tensor of mpc mul op."); AddInput("Y", "(Tensor), The second input tensor of mpc mul op.");
AddOutput("Out", "(Tensor), The output tensor of mpc mul op."); AddOutput("Out", "(Tensor), The output tensor of mpc mul op.");
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<int>( AddAttr<int>(
"x_num_col_dims", "x_num_col_dims",
R"DOC((int, default 1), The mul_op can take tensors with more than two R"DOC((int, default 1), The mul_op can take tensors with more than two
dimensions as its inputs. If the input $X$ is a tensor with more dimensions as its inputs. If the input $X$ is a tensor with more
than two dimensions, $X$ will be flattened into a two-dimensional than two dimensions, $X$ will be flattened into a two-dimensional
matrix first. The flattening rule is: the first `num_col_dims` matrix first. The flattening rule is: the first `num_col_dims`
...@@ -129,109 +129,112 @@ public: ...@@ -129,109 +129,112 @@ public:
Thus, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = Thus, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] =
[24, 30]. [24, 30].
)DOC") )DOC")
.SetDefault(1) .SetDefault(1)
.EqualGreaterThan(1); .EqualGreaterThan(1);
AddAttr<int>( AddAttr<int>(
"y_num_col_dims", "y_num_col_dims",
R"DOC((int, default 1), The mul_op can take tensors with more than two, R"DOC((int, default 1), The mul_op can take tensors with more than two,
dimensions as its inputs. If the input $Y$ is a tensor with more dimensions as its inputs. If the input $Y$ is a tensor with more
than two dimensions, $Y$ will be flattened into a two-dimensional than two dimensions, $Y$ will be flattened into a two-dimensional
matrix first. The attribute `y_num_col_dims` determines how $Y$ is matrix first. The attribute `y_num_col_dims` determines how $Y$ is
flattened. See comments of `x_num_col_dims` for more details. flattened. See comments of `x_num_col_dims` for more details.
)DOC") )DOC")
.SetDefault(1) .SetDefault(1)
.EqualGreaterThan(1); .EqualGreaterThan(1);
AddAttr<float>( AddAttr<float>(
"scale_x", "scale_x",
"scale_x to be used for int8 mul input data x. scale_x has the" "scale_x to be used for int8 mul input data x. scale_x has the"
"same purpose as scale_in in OPs that support quantization." "same purpose as scale_in in OPs that support quantization."
"Only to be used with MKL-DNN INT8") "Only to be used with MKL-DNN INT8")
.SetDefault(1.0f); .SetDefault(1.0f);
AddAttr<std::vector<float>>( AddAttr<std::vector<float>>(
"scale_y", "scale_y",
"scale_y to be used for int8 mul input data y. scale_y has the" "scale_y to be used for int8 mul input data y. scale_y has the"
"same purpose as scale_weights in OPs that support quantization." "same purpose as scale_weights in OPs that support quantization."
"Only to be used with MKL-DNN INT8") "Only to be used with MKL-DNN INT8")
.SetDefault({1.0f}); .SetDefault({1.0f});
AddAttr<float>("scale_out", "scale_out to be used for int8 output data." AddAttr<float>("scale_out",
"Only used with MKL-DNN INT8") "scale_out to be used for int8 output data."
.SetDefault(1.0f); "Only used with MKL-DNN INT8")
AddAttr<bool>( .SetDefault(1.0f);
"force_fp32_output", AddAttr<bool>(
"(bool, default false) Force quantize kernel output FP32, only " "force_fp32_output",
"used in quantized MKL-DNN.") "(bool, default false) Force quantize kernel output FP32, only "
.SetDefault(false); "used in quantized MKL-DNN.")
AddComment(R"DOC( .SetDefault(false);
AddComment(R"DOC(
MPC mul Operator. MPC mul Operator.
)DOC"); )DOC");
} }
}; };
class MpcMulOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { class MpcMulOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected: protected:
std::unordered_map<std::string, std::string> std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
GetInputOutputWithSameType() const override { const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}}; static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
} return m;
}
}; };
class MpcMulGradOp : public framework::OperatorWithKernel { class MpcMulGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
auto out_grad_name = framework::GradVarName("Out"); auto out_grad_name = framework::GradVarName("Out");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null."); PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null."); PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true, PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true,
"Input(Out@GRAD) should not be null."); "Input(Out@GRAD) should not be null.");
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y"); auto y_grad_name = framework::GradVarName("Y");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
if (ctx->HasOutput(x_grad_name)) { if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims); ctx->SetOutputDim(x_grad_name, x_dims);
} }
if (ctx->HasOutput(y_grad_name)) { if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims); ctx->SetOutputDim(y_grad_name, y_dims);
}
} }
}
}; };
template <typename T> template <typename T>
class MpcMulOpGradMaker : public framework::SingleGradOpDescMaker { class MpcMulOpGradMaker : public framework::SingleGradOpMaker<T> {
public: public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> grad) const override {
std::unique_ptr<T> retv(new T()); grad->SetType("mpc_mul_grad");
retv->SetType("mpc_mul_grad"); grad->SetInput("X", this->Input("X"));
retv->SetInput("X", this->Input("X")); grad->SetInput("Y", this->Input("Y"));
retv->SetInput("Y", this->Input("Y")); grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); grad->SetAttrMap(this->Attrs());
retv->SetAttrMap(this->Attrs()); }
return retv;
}
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_mul, ops::MpcMulOp, ops::MpcMulOpMaker, REGISTER_OPERATOR(mpc_mul, ops::MpcMulOp,
ops::MpcMulOpInferVarType, ops::MpcMulOpMaker,
ops::MpcMulOpGradMaker<paddle::framework::OpDesc>); ops::MpcMulOpInferVarType,
ops::MpcMulOpGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_mul_grad, ops::MpcMulGradOp); REGISTER_OPERATOR(mpc_mul_grad, ops::MpcMulGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
mpc_mul, ops::MpcMulKernel<paddle::platform::CPUDeviceContext, int64_t>); mpc_mul,
ops::MpcMulKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
mpc_mul_grad, mpc_mul_grad,
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
#pragma once #pragma once
#include "mpc_op.h" #include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -24,185 +23,170 @@ using Tensor = framework::Tensor; ...@@ -24,185 +23,170 @@ using Tensor = framework::Tensor;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MpcMulKernel : public MpcOpKernel<T> { class MpcMulKernel : public MpcOpKernel<T> {
public: public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override { void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<Tensor>("X"); auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Input<Tensor>("Y"); auto *y = ctx.Input<Tensor>("Y");
auto *out = ctx.Output<Tensor>("Out"); auto *out = ctx.Output<Tensor>("Out");
int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims"); int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims"); int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
auto x_dims = x->dims(); auto x_dims = x->dims();
auto y_dims = y->dims(); auto y_dims = y->dims();
int x_mat_width = 1; int x_mat_width = 1;
int x_mat_height = 1; int x_mat_height = 1;
int y_mat_width = 1; int y_mat_width = 1;
int y_mat_height = 1; int y_mat_height = 1;
for (size_t i = 1; i < x_dims.size(); i++) { for (size_t i = 1; i < x_dims.size(); i++) {
if (i <= x_num_col_dims) { if (i <= x_num_col_dims) {
x_mat_width *= x_dims[i]; x_mat_width *= x_dims[i];
} else { } else {
x_mat_height *= x_dims[i]; x_mat_height *= x_dims[i];
} }
} }
for (size_t i = 1; i < y_dims.size(); i++) { for (size_t i = 1; i < y_dims.size(); i++) {
if (i <= y_num_col_dims) { if (i <= y_num_col_dims) {
x_mat_width *= y_dims[i]; y_mat_width *= y_dims[i];
} else { } else {
y_mat_height *= y_dims[i]; y_mat_height *= y_dims[i];
} }
} }
Tensor x_matrix; Tensor x_matrix;
Tensor y_matrix; Tensor y_matrix;
x_matrix.ShareDataWith(*x); x_matrix.ShareDataWith(*x);
y_matrix.ShareDataWith(*y); y_matrix.ShareDataWith(*y);
if (x_dims.size() > 3) { x_matrix.Resize({2, x_mat_width, x_mat_height});
x_matrix.Resize({2, x_mat_width, x_mat_height}); y_matrix.Resize({2, y_mat_width, y_mat_height});
out->mutable_data<T>(ctx.GetPlace());
auto out_dim = out->dims();
if (out_dim.size() > 3) {
out->Resize({2, x_mat_width, y_mat_height});
}
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(
&x_matrix, &y_matrix, out);
if (out_dim.size() > 3) {
out->Resize(out_dim);
}
} }
if (y_dims.size() > 3) {
y_matrix.Resize({2, y_mat_width, y_mat_height});
}
out->mutable_data<T>(ctx.GetPlace());
auto out_dim = out->dims();
if (out_dim.size() > 3) {
out->Resize({2, x_mat_width, y_mat_height});
}
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(
&x_matrix, &y_matrix, out);
if (out_dim.size() > 3) {
out->Resize(out_dim);
}
}
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MpcMulGradKernel : public MpcOpKernel<T> { class MpcMulGradKernel : public MpcOpKernel<T> {
public: public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override { void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<framework::LoDTensor>("X"); auto* x = ctx.Input<framework::LoDTensor>("X");
auto *y = ctx.Input<framework::LoDTensor>("Y"); auto* y = ctx.Input<framework::LoDTensor>("Y");
auto *dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out")); auto *dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X")); auto *dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto *dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y")); auto *dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims"); int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims"); int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
auto x_dims = x->dims(); auto x_dims = x->dims();
auto y_dims = y->dims(); auto y_dims = y->dims();
auto dout_dims = dout->dims(); auto dout_dims = dout->dims();
int x_mat_width = 1; int x_mat_width = 1;
int x_mat_height = 1; int x_mat_height = 1;
int y_mat_width = 1; int y_mat_width = 1;
int y_mat_height = 1; int y_mat_height = 1;
for (size_t i = 1; i < x_dims.size(); i++) { for (size_t i = 1; i < x_dims.size(); i++) {
if (i <= x_num_col_dims) { if (i <= x_num_col_dims) {
x_mat_width *= x_dims[i]; x_mat_width *= x_dims[i];
} else { } else {
x_mat_height *= x_dims[i]; x_mat_height *= x_dims[i];
} }
} }
for (size_t i = 1; i < y_dims.size(); i++) { for (size_t i = 1; i < y_dims.size(); i++) {
if (i <= y_num_col_dims) { if (i <= y_num_col_dims) {
y_mat_width *= y_dims[i]; y_mat_width *= y_dims[i];
} else { } else {
y_mat_height *= y_dims[i]; y_mat_height *= y_dims[i];
} }
}
Tensor x_matrix;
Tensor y_matrix;
Tensor dout_matrix;
x_matrix.ShareDataWith(*x);
y_matrix.ShareDataWith(*y);
dout_matrix.ShareDataWith(*dout);
x_matrix.Resize({2, x_mat_width, x_mat_height});
y_matrix.Resize({2, y_mat_width, y_mat_height});
dout_matrix.Resize({2, x_mat_width, y_mat_height});
if (dx != nullptr) {
dx->set_lod(x->lod());
}
if (dy != nullptr) {
dy->set_lod(y->lod());
}
Tensor x_matrix_trans;
Tensor y_matrix_trans;
x_matrix_trans.mutable_data<T>(x->dims(), ctx.GetPlace());
y_matrix_trans.mutable_data<T>(y->dims(), ctx.GetPlace());
x_matrix_trans.Resize({2, x_mat_height, x_mat_width});
y_matrix_trans.Resize({2, y_mat_height, y_mat_width});
auto& dev_ctx = ctx.template device_context<DeviceContext>();
const int Rank = 3;
Eigen::array<int, Rank> permute;
permute[0] = 0;
permute[1] = 2;
permute[2] = 1;
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
if (dx->dims().size() > 3) {
dx->Resize({2, x_mat_width, x_mat_height});
}
auto eigen_in = framework::EigenTensor<T, Rank>::From(y_matrix);
auto eigen_out = framework::EigenTensor<T, Rank>::From(y_matrix_trans);
auto* dev = dev_ctx.eigen_device();
eigen_out.device(*dev) = eigen_in.shuffle(permute);
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(
&dout_matrix, &y_matrix_trans, dx);
auto dx_dim = dx->dims();
if (dx_dim.size() > 3) {
dx->Resize(dx_dim);
}
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
if (dy->dims().size() > 3) {
dy->Resize({2, y_mat_width, y_mat_height});
}
auto eigen_in = framework::EigenTensor<T, Rank>::From(x_matrix);
auto eigen_out = framework::EigenTensor<T, Rank>::From(x_matrix_trans);
auto* dev = dev_ctx.eigen_device();
eigen_out.device(*dev) = eigen_in.shuffle(permute);
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(
&x_matrix_trans, &dout_matrix, dy);
auto dy_dim = dy->dims();
if (dy_dim.size() > 3) {
dy->Resize(dy_dim);
}
}
} }
Tensor x_matrix;
Tensor y_matrix;
Tensor dout_matrix;
x_matrix.ShareDataWith(*x);
y_matrix.ShareDataWith(*y);
dout_matrix.ShareDataWith(*dout);
if (x_dims.size() > 3) {
x_matrix.Resize({2, x_mat_width, x_mat_height});
}
if (y_dims.size() > 3) {
y_matrix.Resize({2, y_mat_width, y_mat_height});
}
if (dout_dims.size() > 3) {
dout_matrix.Resize({2, x_mat_width, y_mat_height});
}
if (dx != nullptr) {
dx->set_lod(x->lod());
}
if (dy != nullptr) {
dy->set_lod(y->lod());
}
Tensor x_matrix_trans;
Tensor y_matrix_trans;
x_matrix_trans.mutable_data<T>(x->dims(), ctx.GetPlace());
y_matrix_trans.mutable_data<T>(y->dims(), ctx.GetPlace());
if (x_dims.size() >= 3) {
x_matrix_trans.Resize({2, x_mat_height, x_mat_width});
}
if (y_dims.size() >= 3) {
y_matrix_trans.Resize({2, y_mat_height, y_mat_width});
}
auto &dev_ctx = ctx.template device_context<DeviceContext>();
const int Rank = 3;
Eigen::array<int, Rank> permute;
permute[0] = 0;
permute[1] = 2;
permute[2] = 1;
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
if (dx->dims().size() > 3) {
dx->Resize({2, x_mat_width, x_mat_height});
}
auto eigen_in = framework::EigenTensor<T, Rank>::From(y_matrix);
auto eigen_out = framework::EigenTensor<T, Rank>::From(y_matrix_trans);
auto *dev = dev_ctx.eigen_device();
eigen_out.device(*dev) = eigen_in.shuffle(permute);
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(
&dout_matrix, &y_matrix_trans, dx);
auto dx_dim = dx->dims();
if (dx_dim.size() > 3) {
dx->Resize(dx_dim);
}
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
if (dy->dims().size() > 3) {
dy->Resize({2, y_mat_width, y_mat_height});
}
auto eigen_in = framework::EigenTensor<T, Rank>::From(x_matrix);
auto eigen_out = framework::EigenTensor<T, Rank>::From(x_matrix_trans);
auto *dev = dev_ctx.eigen_device();
eigen_out.device(*dev) = eigen_in.shuffle(permute);
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(
&x_matrix_trans, &dout_matrix, dy);
auto dy_dim = dy->dims();
if (dy_dim.size() > 3) {
dy->Resize(dy_dim);
}
}
}
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
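The Eigen shuffle with permutation {0, 2, 1} above transposes the last two dims of a [2, m, n] share tensor while leaving the share dim in place, so dx = dout * y' and dy = x' * dout can be formed share by share. A standalone index-arithmetic sketch of that layout change:

#include <cstdio>
#include <vector>

int main() {
    const int m = 2, n = 3;                          // hypothetical per-share matrix size
    std::vector<int> in(2 * m * n), out(2 * n * m);
    for (int i = 0; i < 2 * m * n; ++i) in[i] = i;   // in[s][r][c] = s*m*n + r*n + c
    for (int s = 0; s < 2; ++s)                      // share dim stays in place
        for (int r = 0; r < m; ++r)
            for (int c = 0; c < n; ++c)
                out[s * n * m + c * m + r] = in[s * m * n + r * n + c];
    // out[1][0][1] should equal in[1][1][0] = 9
    std::printf("out[1][0][1] = %d\n", out[1 * n * m + 0 * m + 1]);
    return 0;
}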
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
// Description: // Description:
#pragma once #pragma once
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h" #include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/privc3/circuit_context.h" #include "core/privc3/circuit_context.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> class MpcOpKernel : public framework::OpKernelBase { template <typename T>
class MpcOpKernel : public framework::OpKernelBase {
public: public:
using ELEMENT_TYPE = T; using ELEMENT_TYPE = T;
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_instance()->mpc_protocol(), PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_instance()->mpc_protocol(),
"Mpc protocol is not yet initialized in executor"); "Mpc protocol is not yet initialized in executor");
std::shared_ptr<aby3::CircuitContext> mpc_ctx( std::shared_ptr<aby3::CircuitContext> mpc_ctx(mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context()); mpc::ContextHolder::template run_with_context<>(&ctx, mpc_ctx,
mpc::ContextHolder::template run_with_context<>(&ctx, mpc_ctx, [&] { ComputeImpl(ctx); });
[&] { ComputeImpl(ctx); }); }
} virtual void ComputeImpl(const framework::ExecutionContext& ctx) const = 0;
virtual void ComputeImpl(const framework::ExecutionContext &ctx) const = 0;
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
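The base kernel above is a small template-method wrapper: Compute() performs the shared protocol check and context binding, then hands off to the op-specific ComputeImpl(). A standalone sketch of that pattern, with illustrative names rather than the Paddle API:

#include <cstdio>

// Illustrative stand-in for MpcOpKernel: not the Paddle OpKernelBase API.
struct DemoMpcKernelBase {
    void Compute() const {
        // Stands in for the protocol check plus ContextHolder::run_with_context.
        std::printf("bind mpc context, then dispatch\n");
        ComputeImpl();
    }
    virtual void ComputeImpl() const = 0;
    virtual ~DemoMpcKernelBase() = default;
};

struct DemoMeanKernel : DemoMpcKernelBase {
    void ComputeImpl() const override { std::printf("op-specific share computation\n"); }
};

int main() {
    DemoMeanKernel kernel;
    kernel.Compute();
    return 0;
}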
...@@ -18,20 +18,20 @@ ...@@ -18,20 +18,20 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// forward op definition //forward op definition
class MpcReluOp : public framework::OperatorWithKernel { class MpcReluOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim("X"); auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Y", in_dims); ctx->SetOutputDim("Y", in_dims);
} }
}; };
// forward input & output defination //forward input & output defination
class MpcReluOpMaker : public framework::OpProtoAndCheckerMaker { class MpcReluOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "The input tensor."); AddInput("X", "The input tensor.");
AddOutput("Y", "Output of relu_op"); AddOutput("Y", "Output of relu_op");
...@@ -41,43 +41,46 @@ Mpc Relu Operator. ...@@ -41,43 +41,46 @@ Mpc Relu Operator.
} }
}; };
// backward op definition //backward op definition
class MpcReluGradOp : public framework::OperatorWithKernel { class MpcReluGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim(framework::GradVarName("Y")); auto in_dims = ctx->GetInputDim(framework::GradVarName("Y"));
ctx->SetOutputDim(framework::GradVarName("X"), in_dims); ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
} }
}; };
// backward type, input & output definition //backward type, input & output definition
template <typename T> template <typename T>
class MpcReluGradMaker : public framework::SingleGradOpDescMaker { class MpcReluGradMaker : public framework::SingleGradOpMaker<T> {
public: public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override { protected:
auto *op = new T(); void Apply(GradOpPtr<T> grad) const override {
op->SetType("mpc_relu_grad"); grad->SetType("mpc_relu_grad");
op->SetInput("Y", this->Output("Y")); grad->SetInput("Y", this->Output("Y"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); grad->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetAttrMap(this->Attrs()); grad->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
return std::unique_ptr<T>(op); }
}
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext; using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(mpc_relu, ops::MpcReluOp, ops::MpcReluOpMaker, REGISTER_OPERATOR(mpc_relu,
ops::MpcReluOp,
ops::MpcReluOpMaker,
ops::MpcReluGradMaker<paddle::framework::OpDesc>); ops::MpcReluGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_relu_grad, ops::MpcReluGradOp); REGISTER_OPERATOR(mpc_relu_grad, ops::MpcReluGradOp);
REGISTER_OP_CPU_KERNEL(mpc_relu, ops::MpcReluKernel<CPU, int64_t>); REGISTER_OP_CPU_KERNEL(mpc_relu,
REGISTER_OP_CPU_KERNEL(mpc_relu_grad, ops::MpcReluGradKernel<CPU, int64_t>); ops::MpcReluKernel<CPU, int64_t>);
REGISTER_OP_CPU_KERNEL(mpc_relu_grad,
ops::MpcReluGradKernel<CPU, int64_t>);
...@@ -14,43 +14,37 @@ ...@@ -14,43 +14,37 @@
#pragma once #pragma once
#include "mpc_op.h" #include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
// Define forward computation //Define forward computation
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MpcReluKernel : public MpcOpKernel<T> { class MpcReluKernel : public MpcOpKernel<T> {
public: public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override { void ComputeImpl(const framework::ExecutionContext& ctx) const override {
const Tensor *in_t = ctx.Input<Tensor>("X"); const Tensor* in_t = ctx.Input<Tensor>("X");
Tensor *out_t = ctx.Output<Tensor>("Y"); Tensor* out_t = ctx.Output<Tensor>("Y");
auto x = in_t->data<T>(); auto x = in_t->data<T>();
auto y = out_t->mutable_data<T>(ctx.GetPlace()); auto y = out_t->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, "Protocol %s is not yet created in MPC Protocol.");
"Protocol %s is not yet created in MPC Protocol."); mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->relu(in_t,out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->relu(
in_t, out_t);
} }
}; };
// Define backward computation //Define backward computation
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MpcReluGradKernel : public MpcOpKernel<T> { class MpcReluGradKernel : public MpcOpKernel<T> {
public: public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override { void ComputeImpl(const framework::ExecutionContext& ctx) const override {
auto *dy_t = ctx.Input<Tensor>(framework::GradVarName("Y")); auto* dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto *y_t = ctx.Input<Tensor>("Y"); auto* y_t = ctx.Input<Tensor>("Y");
auto *dx_t = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dx = dx_t->mutable_data<T>(ctx.GetPlace()); auto dx = dx_t->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance() mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->relu_grad(y_t, dy_t, dx_t, 0.0);
->mpc_protocol() }
->mpc_operators()
->relu_grad(y_t, dy_t, dx_t, 0.0);
}
}; };
} // namespace operators }// namespace operators
} // namespace paddle }// namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
#include "mpc_sgd_op.h" #include "mpc_sgd_op.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -20,77 +20,77 @@ namespace operators { ...@@ -20,77 +20,77 @@ namespace operators {
class MpcSGDOp : public framework::OperatorWithKernel { class MpcSGDOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"), PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of MPCSGDOp should not be null."); "Input(Param) of MPCSGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"), PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of MPCSGDOp should not be null."); "Input(Grad) of MPCSGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"), PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of MPCSGDOp should not be null."); "Input(LearningRate) of MPCSGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of MPCSGDOp should not be null."); "Output(ParamOut) of MPCSGDOp should not be null.");
auto lr_dims = ctx->GetInputDim("LearningRate"); auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(framework::product(lr_dims), 0, PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
"Maybe the Input variable LearningRate has not " "Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm " "been initialized. You may need to confirm "
"if you put exe.run(startup_program) " "if you put exe.run(startup_program) "
"after optimizer.minimize function."); "after optimizer.minimize function.");
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 element"); "Learning rate should have 1 element");
auto param_dim = ctx->GetInputDim("Param"); auto param_dim = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] == if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) { framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"), param_dim, ctx->GetInputDim("Grad"),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"MPCSGD Operator's input Param and Grad dimensions do not match. " "MPCSGD Operator's input Param and Grad dimensions do not match. "
"The Param %s shape is [%s], but the Grad %s shape is [%s].", "The Param %s shape is [%s], but the Grad %s shape is [%s].",
ctx->Inputs("Param")[0], param_dim, ctx->Inputs("Grad")[0], ctx->Inputs("Param")[0], param_dim, ctx->Inputs("Grad")[0],
ctx->GetInputDim("Grad"))); ctx->GetInputDim("Grad")));
} }
ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
} }
protected: protected:
framework::OpKernelType framework::OpKernelType GetExpectedKernelType(
GetExpectedKernelType(const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
class MpcSGDOpInferVarType : public framework::VarTypeInference { class MpcSGDOpInferVarType : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
auto &input_var_n = ctx->Input("Param")[0]; auto in_var_type = ctx->GetInputType("Param");
auto in_var_type = ctx->GetType(input_var_n); PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS || in_var_type == framework::proto::VarType::LOD_TENSOR,
in_var_type == framework::proto::VarType::LOD_TENSOR, "The input Var's type should be LoDtensor or SelectedRows,"
"The input Var's type should be LoDtensor or SelectedRows," " but the received var(%s)'s type is %s",
" but the received var(%s)'s type is %s", ctx->InputVarName("Param"), in_var_type);
input_var_n, in_var_type); ctx->SetOutputType("ParamOut", in_var_type);
for (auto &out_var_n : ctx->Output("ParamOut")) { //for (auto &out_var_n : framework::StaticGraphVarTypeInference::Output(ctx, "ParamOut")) {
if (ctx->GetType(out_var_n) != in_var_type) { // if (ctx->GetVarType(out_var_n) != in_var_type) {
ctx->SetType(out_var_n, in_var_type); // ctx->SetType(out_var_n, in_var_type);
} //}
} //}
} }
}; };
class MpcSGDOpMaker : public framework::OpProtoAndCheckerMaker { class MpcSGDOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("Param", "(Tensor or SelectedRows) Input parameter"); AddInput("Param", "(Tensor or SelectedRows) Input parameter");
AddInput("LearningRate", "(Tensor) Learning rate of MPCSGD"); AddInput("LearningRate", "(Tensor) Learning rate of MPCSGD");
AddInput("Grad", "(Tensor or SelectedRows) Input gradient"); AddInput("Grad", "(Tensor or SelectedRows) Input gradient");
AddOutput("ParamOut", AddOutput("ParamOut",
"(Tensor or SelectedRows, same with Param) " "(Tensor or SelectedRows, same with Param) "
"Output parameter, should share the same memory with Param"); "Output parameter, should share the same memory with Param");
AddComment(R"DOC( AddComment(R"DOC(
MPCSGD operator MPCSGD operator
...@@ -102,13 +102,13 @@ $$param\_out = param - learning\_rate * grad$$ ...@@ -102,13 +102,13 @@ $$param\_out = param - learning\_rate * grad$$
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR( REGISTER_OPERATOR(
mpc_sgd, ops::MpcSGDOp, ops::MpcSGDOpMaker, mpc_sgd, ops::MpcSGDOp, ops::MpcSGDOpMaker,
// paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
ops::MpcSGDOpInferVarType); ops::MpcSGDOpInferVarType);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
mpc_sgd, ops::MpcSGDOpKernel<paddle::platform::CPUDeviceContext, int64_t>); mpc_sgd,
ops::MpcSGDOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
#pragma once #pragma once
#include "mpc_op.h" #include "mpc_op.h"
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MpcSGDOpKernel : public MpcOpKernel<T> { class MpcSGDOpKernel : public MpcOpKernel<T> {
public: public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override { void ComputeImpl(const framework::ExecutionContext &ctx) const override{
const auto *param_var = ctx.InputVar("Param"); const auto *param_var = ctx.InputVar("Param");
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true, PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
"The Var(%s)'s type should be LoDTensor, " "The Var(%s)'s type should be LoDTensor, "
"but the received is %s", "but the received is %s",
ctx.Inputs("Param").front(), ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type())); framework::ToTypeName(param_var->Type()));
const auto *grad_var = ctx.InputVar("Grad"); const auto *grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true, PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
"The Var(%s)'s type should be LoDTensor, " "The Var(%s)'s type should be LoDTensor, "
"but the received is %s", "but the received is %s",
ctx.Inputs("Grad").front(), ctx.InputNames("Grad").front(),
framework::ToTypeName(grad_var->Type())); framework::ToTypeName(grad_var->Type()));
const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate"); const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
const auto *param = ctx.Input<framework::Tensor>("Param"); const auto *param = ctx.Input<framework::Tensor>("Param");
const auto *grad = ctx.Input<framework::Tensor>("Grad"); const auto *grad = ctx.Input<framework::Tensor>("Grad");
auto *param_out = ctx.Output<framework::Tensor>("ParamOut"); auto *param_out = ctx.Output<framework::Tensor>("ParamOut");
auto sz = param_out->numel(); auto sz = param_out->numel();
PADDLE_ENFORCE_EQ(param->numel(), sz); PADDLE_ENFORCE_EQ(param->numel(), sz);
PADDLE_ENFORCE_EQ(grad->numel(), sz); PADDLE_ENFORCE_EQ(grad->numel(), sz);
const double *lr = learning_rate->data<double>(); const double *lr = learning_rate->data<double>();
// const T *param_data = param->data<T>();
// const T *grad_data = grad->data<T>(); param_out->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, "Protocol %s is not yet created in MPC Protocol.");
T *out_data = param_out->mutable_data<T>(ctx.GetPlace()); // update parameters
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, framework::Tensor temp;
"Protocol %s is not yet created in MPC Protocol."); temp.mutable_data<T>(param->dims(), ctx.GetPlace());
// update parameters mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(grad, lr[0], &temp);
framework::Tensor temp; mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(param, &temp, param_out);
temp.mutable_data<T>(param->dims(), ctx.GetPlace()); }
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(
grad, lr[0], &temp);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(
param, &temp, param_out);
}
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
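Note: the update computed by MpcSGDOpKernel above is ordinary SGD expressed with the protocol's scale and sub operators: temp = lr * grad, then param_out = param - temp. A minimal plaintext sketch of the same arithmetic in numpy (the helper name is illustrative, not part of the PaddleFL API):

    import numpy as np

    def sgd_update(param, grad, lr):
        temp = lr * grad          # mpc_operators()->scale(grad, lr[0], &temp)
        return param - temp       # mpc_operators()->sub(param, &temp, param_out)

    param = np.array([0.5, -0.2])
    grad = np.array([0.1, 0.4])
    print(sgd_update(param, grad, lr=0.01))   # [ 0.499 -0.204]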
...@@ -117,21 +117,19 @@ MpcSigmoidCrossEntropyWithLogits Operator. ...@@ -117,21 +117,19 @@ MpcSigmoidCrossEntropyWithLogits Operator.
}; };
template <typename T> template <typename T>
class MpcSigmoidCrossEntropyWithLogitsGradOpMaker : public framework::SingleGradOpDescMaker { class MpcSigmoidCrossEntropyWithLogitsGradOpMaker : public framework::SingleGradOpMaker<T> {
public: public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> grad) const override {
std::unique_ptr<T> retv(new T()); grad->SetType("mpc_sigmoid_cross_entropy_with_logits_grad");
retv->SetType("mpc_sigmoid_cross_entropy_with_logits_grad"); grad->SetInput("X", this->Input("X"));
retv->SetInput("X", this->Input("X")); grad->SetInput("Label", this->Input("Label"));
retv->SetInput("Label", this->Input("Label")); grad->SetInput("Out", this->Output("Out"));
retv->SetInput("Out", this->Output("Out")); grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); grad->SetAttrMap(this->Attrs());
retv->SetAttrMap(this->Attrs());
return retv;
} }
}; };
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
#include "mpc_square_op.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "mpc_square_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
class MpcSquareOp : public framework::OperatorWithKernel { class MpcSquareOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, PADDLE_ENFORCE_EQ(
platform::errors::NotFound( ctx->HasInput("X"), true,
"Input(X) of MpcSquareOp should not be null.")); platform::errors::NotFound("Input(X) of MpcSquareOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, PADDLE_ENFORCE_EQ(
platform::errors::NotFound( ctx->HasOutput("Out"), true,
"Output(Out) of MpcSquareOp should not be null.")); platform::errors::NotFound("Output(Out) of MpcSquareOp should not be null."));
ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
}; };
class MpcSquareOpMaker : public framework::OpProtoAndCheckerMaker { class MpcSquareOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "(Tensor), The first input tensor of mpc square op."); AddInput("X", "(Tensor), The first input tensor of mpc square op.");
AddOutput("Out", "(Tensor), The output tensor of mpc square op."); AddOutput("Out", "(Tensor), The output tensor of mpc square op.");
AddComment(R"DOC( AddComment(R"DOC(
MPC square Operator. MPC square Operator.
)DOC"); )DOC");
} }
}; };
class MpcSquareGradOp : public framework::OperatorWithKernel { class MpcSquareGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
ctx->ShareDim(framework::GradVarName("Out"), framework::GradVarName("X")); ctx->ShareDim(framework::GradVarName("Out"), framework::GradVarName("X"));
ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X")); ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X"));
} }
}; };
template <typename T> template <typename T>
class MpcSquareGradOpMaker : public framework::SingleGradOpDescMaker { class MpcSquareGradOpMaker : public framework::SingleGradOpMaker<T> {
public: public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> grad) const override {
std::unique_ptr<T> retv(new T()); grad->SetType("mpc_square_grad");
retv->SetType("mpc_square_grad"); grad->SetInput("X", this->Input("X"));
retv->SetInput("X", this->Input("X")); grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); }
return retv;
}
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_square, ops::MpcSquareOp, ops::MpcSquareOpMaker, REGISTER_OPERATOR(mpc_square, ops::MpcSquareOp,
ops::MpcSquareGradOpMaker<paddle::framework::OpDesc>); ops::MpcSquareOpMaker,
ops::MpcSquareGradOpMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_square_grad, ops::MpcSquareGradOp); REGISTER_OPERATOR(mpc_square_grad, ops::MpcSquareGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
mpc_square, mpc_square,
ops::MpcSquareKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::MpcSquareKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
mpc_square_grad, mpc_square_grad,
ops::MpcSquareGradKernel<paddle::platform::CPUDeviceContext, int64_t>); ops::MpcSquareGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
#pragma once #pragma once
#include "mpc_op.h" #include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -23,33 +23,31 @@ using Tensor = framework::Tensor; ...@@ -23,33 +23,31 @@ using Tensor = framework::Tensor;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MpcSquareKernel : public MpcOpKernel<T> { class MpcSquareKernel : public MpcOpKernel<T> {
public: public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override { void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *in_x_t = ctx.Input<Tensor>("X"); auto *in_x_t = ctx.Input<Tensor>("X");
auto *out_t = ctx.Output<Tensor>("Out"); auto *out_t = ctx.Output<Tensor>("Out");
out_t->mutable_data<T>(ctx.GetPlace()); out_t->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul( mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul(in_x_t, in_x_t, out_t);
in_x_t, in_x_t, out_t); }
}
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MpcSquareGradKernel : public MpcOpKernel<T> { class MpcSquareGradKernel : public MpcOpKernel<T> {
public: public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override { void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *in_x_t = ctx.Input<Tensor>("X"); auto *in_x_t = ctx.Input<Tensor>("X");
auto *dout_t = ctx.Input<Tensor>(framework::GradVarName("Out")); auto *dout_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx_t = ctx.Output<Tensor>(framework::GradVarName("X")); auto *dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
if (dx_t != nullptr) { if (dx_t != nullptr) {
// allocate memory on device. // allocate memory on device.
dx_t->mutable_data<T>(ctx.GetPlace()); dx_t->mutable_data<T>(ctx.GetPlace());
// dx = dout * 2 * x // dx = dout * 2 * x
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale( mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(in_x_t, 2.0, dx_t);
in_x_t, 2.0, dx_t); mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul(dx_t, dout_t, dx_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul( }
dx_t, dout_t, dx_t);
} }
}
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
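Note: on plaintext values the two kernels above reduce to out = x * x in the forward pass and dx = 2 * x * dout in the backward pass (scale by 2, then multiply by the upstream gradient). A small numpy sketch of that math, for illustration only:

    import numpy as np

    x = np.array([1.0, -2.0, 3.0])
    dout = np.array([0.1, 0.1, 0.1])   # upstream gradient d(loss)/d(out)

    out = x * x                        # forward: mul(in_x_t, in_x_t, out_t)
    dx = 2.0 * x * dout                # backward: scale(x, 2.0, dx); mul(dx, dout, dx)

    print(out)   # [1. 4. 9.]
    print(dx)    # [ 0.2 -0.4  0.6]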
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "mpc_sum_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type_inference.h" #include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_sum_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -29,131 +28,135 @@ using Tensor = framework::Tensor; ...@@ -29,131 +28,135 @@ using Tensor = framework::Tensor;
class MpcSumOp : public framework::OperatorWithKernel { class MpcSumOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx->HasInputs("X"), true, ctx->HasInputs("X"), true,
platform::errors::NotFound( platform::errors::NotFound("Input(X) of MpcSumOp should not be null."));
"Input(X) of MpcSumOp should not be null."));
PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ( ctx->HasOutput("Out"), true,
ctx->HasOutput("Out"), true, platform::errors::NotFound("Output(Out) of MulOp should not be null."));
platform::errors::NotFound("Output(Out) of MulOp should not be null."));
auto x_var_types = ctx->GetInputsVarType("X");
auto x_var_types = ctx->GetInputsVarType("X"); auto x_dims = ctx->GetInputsDim("X");
auto x_dims = ctx->GetInputsDim("X"); auto N = x_dims.size();
auto N = x_dims.size(); PADDLE_ENFORCE_GT(
PADDLE_ENFORCE_GT( N, 0,
N, 0, "ShapeError: The input tensor X's dimensions of SumOp " "ShapeError: The input tensor X's dimensions of SumOp "
"should be larger than 0. But received X's dimensions %d, " "should be larger than 0. But received X's dimensions %d, "
"X's shape = [%s].", "X's shape = [%s].",
N, &x_dims); N, &x_dims);
if (N == 1) { if (N == 1) {
VLOG(3) << "Warning: SumOp have only one input, may waste memory"; VLOG(3) << "Warning: SumOp have only one input, may waste memory";
} }
framework::DDim in_dim({0}); framework::DDim in_dim({0});
for (size_t i = 0; i < x_dims.size(); ++i) { for (size_t i = 0; i < x_dims.size(); ++i) {
auto &x_dim = x_dims[i]; auto& x_dim = x_dims[i];
// x_dim.size() == 1 means the real dim of selected rows is [0] // x_dim.size() == 1 means the real dim of selected rows is [0]
if (x_var_types[i] == framework::proto::VarType::SELECTED_ROWS && if (x_var_types[i] == framework::proto::VarType::SELECTED_ROWS &&
x_dim.size() == 1) { x_dim.size() == 1) {
continue; continue;
}
if (framework::product(x_dim) == 0) {
continue;
}
if (framework::product(in_dim) == 0) {
in_dim = x_dim;
} else {
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
in_dim, x_dim,
"ShapeError: The input tensor X of SumOp must have same shape."
"But received X[0]'s shape = [%s], X[%d]'s shape = [%s].",
in_dim, i, x_dim);
} else {
PADDLE_ENFORCE_EQ(
in_dim.size(), x_dim.size(),
"ShapeError: The input tensor X of SumOp must have same "
"dimensions. But received X[0]'s dimensions = %d, X[0]'s shape = "
"[%s], X[%d]'s dimensions = %d, X[%d]'s shape = [%s].",
in_dim.size(), in_dim, i, x_dim.size(), i, x_dim);
// if in_dim or x_dim has -1, not check equal
for (int j = 0; j < x_dim.size(); ++j) {
if (x_dim[j] == -1 || in_dim[j] == -1) {
continue;
} }
PADDLE_ENFORCE_EQ( if (framework::product(x_dim) == 0) {
in_dim[j], x_dim[j], continue;
"ShapeError: The input tensor X of SumOp must have same shape " }
"if not -1." if (framework::product(in_dim) == 0) {
"But received X[0]'s shape = [%s], X[%d]'s shape = [%s].", in_dim = x_dim;
in_dim, i, x_dim); } else {
} if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
in_dim, x_dim,
"ShapeError: The input tensor X of SumOp must have same shape."
"But received X[0]'s shape = [%s], X[%d]'s shape = [%s].",
in_dim, i, x_dim);
} else {
PADDLE_ENFORCE_EQ(
in_dim.size(), x_dim.size(),
"ShapeError: The input tensor X of SumOp must have same "
"dimensions. But received X[0]'s dimensions = %d, X[0]'s shape = "
"[%s], X[%d]'s dimensions = %d, X[%d]'s shape = [%s].",
in_dim.size(), in_dim, i, x_dim.size(), i, x_dim);
// if in_dim or x_dim has -1, not check equal
for (int j = 0; j < x_dim.size(); ++j) {
if (x_dim[j] == -1 || in_dim[j] == -1) {
continue;
}
PADDLE_ENFORCE_EQ(
in_dim[j], x_dim[j],
"ShapeError: The input tensor X of SumOp must have same shape "
"if not -1."
"But received X[0]'s shape = [%s], X[%d]'s shape = [%s].",
in_dim, i, x_dim);
}
}
}
} }
}
ctx->SetOutputDim("Out", in_dim);
ctx->ShareLoD("X", /*->*/ "Out");
} }
ctx->SetOutputDim("Out", in_dim);
ctx->ShareLoD("X", /*->*/ "Out");
}
}; };
class MpcSumOpMaker : public framework::OpProtoAndCheckerMaker { class MpcSumOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", AddInput("X",
"A Varaible list. The shape and data type of the list elements" "A Varaible list. The shape and data type of the list elements"
"should be consistent. Variable can be multi-dimensional Tensor" "should be consistent. Variable can be multi-dimensional Tensor"
"or LoDTensor, and data types can be: float32, float64, int32, " "or LoDTensor, and data types can be: float32, float64, int32, "
"int64.") "int64.")
.AsDuplicable(); .AsDuplicable();
AddOutput("Out", "the sum of input :code:`x`. its shape and data types are " AddOutput("Out",
"consistent with :code:`x`."); "the sum of input :code:`x`. its shape and data types are "
AddAttr<bool>("use_mkldnn", "consistent with :code:`x`.");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddComment(R"DOC(This OP is used to sum one or more Tensor or LoDTensor AddComment(R"DOC(This OP is used to sum one or more Tensor or LoDTensor
of the input. If the input is LoDTensor, the output only of the input. If the input is LoDTensor, the output only
shares LoD information with the first input.)DOC"); shares LoD information with the first input.)DOC");
} }
}; };
class MpcSumGradMaker : public framework::GradOpDescMakerBase { class MpcSumGradMaker : public framework::GradOpDescMakerBase {
public: public:
using framework::GradOpDescMakerBase::GradOpDescMakerBase; using framework::GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override { std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
auto x_grads = InputGrad("X", false); auto x_grads = InputGrad("X", false);
std::vector<std::unique_ptr<framework::OpDesc>> grad_ops; std::vector<std::unique_ptr<framework::OpDesc>> grad_ops;
grad_ops.reserve(x_grads.size()); grad_ops.reserve(x_grads.size());
auto og = OutputGrad("Out"); auto og = OutputGrad("Out");
std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops), std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
[&og](const std::string &x_grad) { [&og](const std::string& x_grad) {
auto *grad_op = new framework::OpDesc(); auto* grad_op = new framework::OpDesc();
grad_op->SetType("scale"); grad_op->SetType("scale");
grad_op->SetInput("X", og); grad_op->SetInput("X", og);
grad_op->SetOutput("Out", {x_grad}); grad_op->SetOutput("Out", {x_grad});
grad_op->SetAttr("scale", 1.0f); grad_op->SetAttr("scale", 1.0f);
return std::unique_ptr<framework::OpDesc>(grad_op); return std::unique_ptr<framework::OpDesc>(grad_op);
}); });
return grad_ops; return grad_ops;
} }
}; };
DECLARE_INPLACE_OP_INFERER(MpcSumInplace, {"X", "Out"}); DECLARE_INPLACE_OP_INFERER(MpcSumInplace, {"X", "Out"});
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
// REGISTER_OP_WITHOUT_GRADIENT(mpc_sum, ops::MpcSumOp, ops::MpcSumOpMaker); //REGISTER_OP_WITHOUT_GRADIENT(mpc_sum, ops::MpcSumOp, ops::MpcSumOpMaker);
REGISTER_OPERATOR(mpc_sum, ops::MpcSumOp, ops::MpcSumOpMaker, REGISTER_OPERATOR(mpc_sum, ops::MpcSumOp,
ops::MpcSumGradMaker, ops::MpcSumInplace); ops::MpcSumOpMaker,
ops::MpcSumGradMaker,
ops::MpcSumInplace);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(mpc_sum, ops::MpcSumKernel<paddle::platform::CPUDeviceContext, int64_t>);
mpc_sum, ops::MpcSumKernel<paddle::platform::CPUDeviceContext, int64_t>);
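Note: MpcSumGradMaker needs no dedicated mpc_sum_grad kernel, because for out = x_1 + ... + x_n every input gradient equals the output gradient; each input therefore gets a plain scale op with scale = 1.0. A plaintext sketch of that rule (the function name is illustrative):

    import numpy as np

    def sum_grad(d_out, num_inputs):
        # each input receives the output gradient unchanged (scale = 1.0)
        return [1.0 * d_out for _ in range(num_inputs)]

    d_out = np.array([0.5, -1.0])
    print(sum_grad(d_out, num_inputs=3))   # three copies of [ 0.5 -1. ]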
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
#pragma once #pragma once
#include "mpc_op.h" #include "mpc_op.h"
...@@ -23,62 +23,57 @@ using Tensor = framework::Tensor; ...@@ -23,62 +23,57 @@ using Tensor = framework::Tensor;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MpcSumKernel : public MpcOpKernel<T> { class MpcSumKernel : public MpcOpKernel<T> {
public: public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override { void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto in_vars = ctx.MultiInputVar("X"); auto in_vars = ctx.MultiInputVar("X");
size_t in_num = in_vars.size(); size_t in_num = in_vars.size();
auto out_var = ctx.OutputVar("Out"); auto out_var = ctx.OutputVar("Out");
bool in_place = out_var == in_vars[0]; bool in_place = out_var == in_vars[0];
if (out_var->IsType<framework::LoDTensor>()) { if (out_var->IsType<framework::LoDTensor>()) {
auto *out = out_var->GetMutable<framework::LoDTensor>(); auto *out = out_var->GetMutable<framework::LoDTensor>();
auto *out_ptr = out->mutable_data<T>(ctx.GetPlace()); auto *out_ptr = out->mutable_data<T>(ctx.GetPlace());
if (in_num >= 1 && in_vars[0]->IsType<framework::LoDTensor>()) { if (in_num >= 1 && in_vars[0]->IsType<framework::LoDTensor>()) {
auto &in_0_tensor = in_vars[0]->Get<framework::LoDTensor>(); auto &in_0_tensor = in_vars[0]->Get<framework::LoDTensor>();
if (in_0_tensor.numel() > 0) { if (in_0_tensor.numel() > 0) {
in_place = (in_0_tensor.data<T>() == out_ptr); in_place = (in_0_tensor.data<T>() == out_ptr);
} }
} }
int start = in_place ? 1 : 0; int start = in_place ? 1 : 0;
if (!in_place) { if (!in_place) {
if ((in_num >= 2) && in_vars[0]->IsType<framework::LoDTensor>() && if ((in_num >= 2) && in_vars[0]->IsType<framework::LoDTensor>() &&
in_vars[1]->IsType<framework::LoDTensor>()) { in_vars[1]->IsType<framework::LoDTensor>()) {
auto &in_0 = in_vars[0]->Get<framework::LoDTensor>(); auto &in_0 = in_vars[0]->Get<framework::LoDTensor>();
auto &in_1 = in_vars[1]->Get<framework::LoDTensor>(); auto &in_1 = in_vars[1]->Get<framework::LoDTensor>();
if (in_0.numel() && in_1.numel()) { if (in_0.numel() && in_1.numel()) {
mpc::MpcInstance::mpc_instance() mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(&in_0, &in_1, out);
->mpc_protocol() start = 2;
->mpc_operators() }
->add(&in_0, &in_1, out); }
start = 2; if (start != 2) {
} auto t = framework::EigenVector<T>::Flatten(*out);
} auto &device_ctx = ctx.template device_context<DeviceContext>();
if (start != 2) { t.device(*device_ctx.eigen_device()) = t.constant(static_cast<T>(0));
auto t = framework::EigenVector<T>::Flatten(*out); }
auto &device_ctx = ctx.template device_context<DeviceContext>(); }
t.device(*device_ctx.eigen_device()) = t.constant(static_cast<T>(0));
}
}
// If in_place, just skip the first tensor // If in_place, just skip the first tensor
for (size_t i = start; i < in_num; i++) { for (size_t i = start; i < in_num; i++) {
if (in_vars[i]->IsType<framework::LoDTensor>()) { if (in_vars[i]->IsType<framework::LoDTensor>()) {
auto &in_t = in_vars[i]->Get<framework::LoDTensor>(); auto &in_t = in_vars[i]->Get<framework::LoDTensor>();
if (in_t.numel() == 0) { if (in_t.numel() == 0) {
continue; continue;
} }
mpc::MpcInstance::mpc_instance() mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(out, &in_t, out);
->mpc_protocol() } else {
->mpc_operators() PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
->add(out, &in_t, out); }
} else { }
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows."); }else {
PADDLE_THROW("Unexpected branch, output variable type is %s",
framework::ToTypeName(out_var->Type()));
} }
}
} else {
PADDLE_THROW("Unexpected branch, output variable type is %s",
framework::ToTypeName(out_var->Type()));
} }
}
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
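Note: the control flow of MpcSumKernel is easier to follow on plaintext values: if the output does not alias input 0, the first two non-empty inputs are added straight into out (otherwise out is zero-initialized), and the remaining inputs are accumulated one by one; if the output aliases input 0, accumulation starts from input 1. A numpy sketch of that logic under those assumptions:

    import numpy as np

    def mpc_sum_like(inputs, in_place=False):
        if in_place:
            out, start = inputs[0].copy(), 1           # out already holds inputs[0]
        elif len(inputs) >= 2:
            out, start = inputs[0] + inputs[1], 2      # add(&in_0, &in_1, out)
        else:
            out, start = np.zeros_like(inputs[0]), 0   # zero-initialize, then add everything
        for t in inputs[start:]:
            out = out + t                              # add(out, &in_t, out)
        return out

    xs = [np.array([1, 2]), np.array([3, 4]), np.array([5, 6])]
    print(mpc_sum_like(xs))   # [ 9 12]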
This diff has been collapsed.
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/framework/tensor.h"
namespace aby3 { namespace aby3 {
...@@ -28,271 +28,274 @@ using paddle::framework::Tensor; ...@@ -28,271 +28,274 @@ using paddle::framework::Tensor;
class PaddleTensorTest : public ::testing::Test { class PaddleTensorTest : public ::testing::Test {
public: public:
std::shared_ptr<TensorAdapterFactory> _tensor_factory;
CPUDeviceContext _cpu_ctx; std::shared_ptr<TensorAdapterFactory> _tensor_factory;
CPUDeviceContext _cpu_ctx;
void SetUp() { virtual ~PaddleTensorTest() noexcept {}
_tensor_factory = std::make_shared<PaddleTensorFactory>(&_cpu_ctx);
} void SetUp() {
_tensor_factory = std::make_shared<PaddleTensorFactory>(&_cpu_ctx);
}
}; };
TEST_F(PaddleTensorTest, factory_test) { TEST_F(PaddleTensorTest, factory_test) {
EXPECT_NO_THROW(_tensor_factory->template create<int64_t>()); EXPECT_NO_THROW(_tensor_factory->template create<int64_t>());
std::vector<size_t> shape = {2, 3}; std::vector<size_t> shape = { 2, 3 };
EXPECT_NO_THROW(_tensor_factory->template create<int64_t>(shape)); EXPECT_NO_THROW(_tensor_factory->template create<int64_t>(shape));
} }
TEST_F(PaddleTensorTest, ctor_test) { TEST_F(PaddleTensorTest, ctor_test) {
Tensor t; Tensor t;
// t holds no memory // t holds no memory
EXPECT_THROW({ PaddleTensor<int64_t> pt(&_cpu_ctx, t); }, EXPECT_THROW({ PaddleTensor<int64_t> pt(&_cpu_ctx, t); }, ::paddle::platform::EnforceNotMet);
::paddle::platform::EnforceNotMet); t.template mutable_data<int64_t>(_cpu_ctx.GetPlace());
t.template mutable_data<int64_t>(_cpu_ctx.GetPlace()); EXPECT_NO_THROW({ PaddleTensor<int64_t> pt(&_cpu_ctx, t); });
EXPECT_NO_THROW({ PaddleTensor<int64_t> pt(&_cpu_ctx, t); });
} }
TEST_F(PaddleTensorTest, shape_test) { TEST_F(PaddleTensorTest, shape_test) {
std::vector<size_t> shape = {2, 3}; std::vector<size_t> shape = { 2, 3 };
auto pt = _tensor_factory->template create<int64_t>(shape); auto pt = _tensor_factory->template create<int64_t>(shape);
EXPECT_EQ(shape.size(), pt->shape().size()); EXPECT_EQ(shape.size(), pt->shape().size());
bool eq = std::equal(shape.begin(), shape.end(), pt->shape().begin()); bool eq = std::equal(shape.begin(), shape.end(), pt->shape().begin());
EXPECT_TRUE(eq); EXPECT_TRUE(eq);
EXPECT_EQ(6u, pt->numel()); EXPECT_EQ(6u, pt->numel());
} }
TEST_F(PaddleTensorTest, reshape_test) { TEST_F(PaddleTensorTest, reshape_test) {
std::vector<size_t> shape = {2, 3}; std::vector<size_t> shape = { 2, 3 };
auto pt = _tensor_factory->template create<int64_t>(); auto pt = _tensor_factory->template create<int64_t>();
pt->reshape(shape); pt->reshape(shape);
EXPECT_EQ(shape.size(), pt->shape().size()); EXPECT_EQ(shape.size(), pt->shape().size());
bool eq = std::equal(shape.begin(), shape.end(), pt->shape().begin()); bool eq = std::equal(shape.begin(), shape.end(), pt->shape().begin());
EXPECT_TRUE(eq); EXPECT_TRUE(eq);
} }
TEST_F(PaddleTensorTest, add_test) { TEST_F(PaddleTensorTest, add_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape); auto pt2 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 1; pt0->data()[0] = 1;
pt1->data()[0] = 2; pt1->data()[0] = 2;
pt0->add(pt1.get(), pt2.get()); pt0->add(pt1.get(), pt2.get());
EXPECT_EQ(3, pt2->data()[0]); EXPECT_EQ(3, pt2->data()[0]);
} }
TEST_F(PaddleTensorTest, sub_test) { TEST_F(PaddleTensorTest, sub_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape); auto pt2 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 2; pt0->data()[0] = 2;
pt1->data()[0] = 1; pt1->data()[0] = 1;
pt0->sub(pt1.get(), pt2.get()); pt0->sub(pt1.get(), pt2.get());
EXPECT_EQ(1, pt2->data()[0]); EXPECT_EQ(1, pt2->data()[0]);
} }
TEST_F(PaddleTensorTest, negative_test) { TEST_F(PaddleTensorTest, negative_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 2; pt0->data()[0] = 2;
pt0->negative(pt1.get()); pt0->negative(pt1.get());
EXPECT_EQ(-2, pt1->data()[0]); EXPECT_EQ(-2, pt1->data()[0]);
} }
TEST_F(PaddleTensorTest, mul_test) { TEST_F(PaddleTensorTest, mul_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape); auto pt2 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 7; pt0->data()[0] = 7;
pt1->data()[0] = 3; pt1->data()[0] = 3;
pt0->mul(pt1.get(), pt2.get()); pt0->mul(pt1.get(), pt2.get());
EXPECT_EQ(21, pt2->data()[0]); EXPECT_EQ(21, pt2->data()[0]);
} }
TEST_F(PaddleTensorTest, div_test) { TEST_F(PaddleTensorTest, div_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape); auto pt2 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 7; pt0->data()[0] = 7;
pt1->data()[0] = 3; pt1->data()[0] = 3;
pt0->div(pt1.get(), pt2.get()); pt0->div(pt1.get(), pt2.get());
EXPECT_EQ(2, pt2->data()[0]); EXPECT_EQ(2, pt2->data()[0]);
} }
TEST_F(PaddleTensorTest, matmul_test) { TEST_F(PaddleTensorTest, matmul_test) {
std::vector<size_t> shape0 = {2, 3}; std::vector<size_t> shape0 = { 2, 3 };
std::vector<size_t> shape1 = {3, 2}; std::vector<size_t> shape1 = { 3, 2 };
std::vector<size_t> shape2 = {2, 2}; std::vector<size_t> shape2 = { 2, 2 };
auto pt0 = _tensor_factory->template create<int64_t>(shape0); auto pt0 = _tensor_factory->template create<int64_t>(shape0);
auto pt1 = _tensor_factory->template create<int64_t>(shape1); auto pt1 = _tensor_factory->template create<int64_t>(shape1);
auto pt2 = _tensor_factory->template create<int64_t>(shape2); auto pt2 = _tensor_factory->template create<int64_t>(shape2);
for (size_t i = 0; i < 6; ++i) { for (size_t i = 0; i < 6; ++i) {
pt0->data()[i] = i; pt0->data()[i] = i;
pt1->data()[i] = i; pt1->data()[i] = i;
} }
pt0->mat_mul(pt1.get(), pt2.get()); pt0->mat_mul(pt1.get(), pt2.get());
// | 0 1 2 | | 0 1 | | 10 13 | // | 0 1 2 | | 0 1 | | 10 13 |
// | 3 4 5 | x | 2 3 | = | 28 40 | // | 3 4 5 | x | 2 3 | = | 28 40 |
// | 4 5 | // | 4 5 |
std::vector<int64_t> res = {10, 13, 28, 40}; std::vector<int64_t> res = { 10, 13, 28, 40 };
bool eq = std::equal(res.begin(), res.end(), pt2->data()); bool eq = std::equal(res.begin(), res.end(), pt2->data());
EXPECT_TRUE(eq); EXPECT_TRUE(eq);
} }
TEST_F(PaddleTensorTest, xor_test) { TEST_F(PaddleTensorTest, xor_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape); auto pt2 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 3; pt0->data()[0] = 3;
pt1->data()[0] = 7; pt1->data()[0] = 7;
pt0->bitwise_xor(pt1.get(), pt2.get()); pt0->bitwise_xor(pt1.get(), pt2.get());
EXPECT_EQ(4, pt2->data()[0]); EXPECT_EQ(4, pt2->data()[0]);
} }
TEST_F(PaddleTensorTest, and_test) { TEST_F(PaddleTensorTest, and_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape); auto pt2 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 3; pt0->data()[0] = 3;
pt1->data()[0] = 7; pt1->data()[0] = 7;
pt0->bitwise_and(pt1.get(), pt2.get()); pt0->bitwise_and(pt1.get(), pt2.get());
EXPECT_EQ(3, pt2->data()[0]); EXPECT_EQ(3, pt2->data()[0]);
} }
TEST_F(PaddleTensorTest, or_test) { TEST_F(PaddleTensorTest, or_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape); auto pt2 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 3; pt0->data()[0] = 3;
pt1->data()[0] = 7; pt1->data()[0] = 7;
pt0->bitwise_or(pt1.get(), pt2.get()); pt0->bitwise_or(pt1.get(), pt2.get());
EXPECT_EQ(7, pt2->data()[0]); EXPECT_EQ(7, pt2->data()[0]);
} }
TEST_F(PaddleTensorTest, not_test) { TEST_F(PaddleTensorTest, not_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 0; pt0->data()[0] = 0;
pt0->bitwise_not(pt1.get()); pt0->bitwise_not(pt1.get());
EXPECT_EQ(-1, pt1->data()[0]); EXPECT_EQ(-1, pt1->data()[0]);
} }
TEST_F(PaddleTensorTest, lshift_test) { TEST_F(PaddleTensorTest, lshift_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 2; pt0->data()[0] = 2;
pt0->lshift(1, pt1.get()); pt0->lshift(1, pt1.get());
EXPECT_EQ(4, pt1->data()[0]); EXPECT_EQ(4, pt1->data()[0]);
} }
TEST_F(PaddleTensorTest, rshift_test) { TEST_F(PaddleTensorTest, rshift_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 2; pt0->data()[0] = 2;
pt0->rshift(1, pt1.get()); pt0->rshift(1, pt1.get());
EXPECT_EQ(1, pt1->data()[0]); EXPECT_EQ(1, pt1->data()[0]);
} }
TEST_F(PaddleTensorTest, logical_rshift_test) { TEST_F(PaddleTensorTest, logical_rshift_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = -1; pt0->data()[0] = -1;
pt0->logical_rshift(1, pt1.get()); pt0->logical_rshift(1, pt1.get());
EXPECT_EQ(-1ull >> 1, pt1->data()[0]); EXPECT_EQ(-1ull >> 1, pt1->data()[0]);
} }
TEST_F(PaddleTensorTest, scale_test) { TEST_F(PaddleTensorTest, scale_test) {
auto pt = _tensor_factory->template create<int64_t>(); auto pt = _tensor_factory->template create<int64_t>();
auto pt_ = dynamic_cast<PaddleTensor<int64_t> *>(pt.get()); auto pt_ = dynamic_cast<PaddleTensor<int64_t>*>(pt.get());
pt_->scaling_factor() = 1; pt_->scaling_factor() = 1;
Tensor t; Tensor t;
int dim[1] = {1}; int dim[1] = { 1 };
paddle::framework::DDim ddim(dim, 1); paddle::framework::DDim ddim(dim, 1);
t.template mutable_data<float>(ddim, _cpu_ctx.GetPlace()); t.template mutable_data<float>(ddim, _cpu_ctx.GetPlace());
t.template data<float>()[0] = 0.25f; t.template data<float>()[0] = 0.25f;
pt_->template from_float_point_type<float>(t, 2); pt_->template from_float_point_type<float>(t, 2);
EXPECT_EQ(2, pt_->scaling_factor()); EXPECT_EQ(2, pt_->scaling_factor());
EXPECT_EQ(1, pt->data()[0]); EXPECT_EQ(1, pt->data()[0]);
} }
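Note: the expected values in scale_test (and scalar_test below) follow from the fixed-point encoding used by PaddleTensor: a float v stored with scaling factor N becomes the integer v * 2^N, so 0.25 with N = 2 is stored as 1. A quick sketch of that conversion (illustrative helpers, not the library API):

    def to_fixed_point(v, scaling_factor):
        # store v as an integer scaled by 2**scaling_factor
        return int(round(v * (1 << scaling_factor)))

    def from_fixed_point(x, scaling_factor):
        return x / (1 << scaling_factor)

    assert to_fixed_point(0.25, 2) == 1
    assert from_fixed_point(1, 2) == 0.25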
TEST_F(PaddleTensorTest, scalar_test) { TEST_F(PaddleTensorTest, scalar_test) {
auto pt = _tensor_factory->template create<int64_t>(); auto pt = _tensor_factory->template create<int64_t>();
auto pt_ = dynamic_cast<PaddleTensor<int64_t> *>(pt.get()); auto pt_ = dynamic_cast<PaddleTensor<int64_t>*>(pt.get());
pt_->scaling_factor() = 1; pt_->scaling_factor() = 1;
std::vector<size_t> shape = {2}; std::vector<size_t> shape = { 2 };
pt_->template from_float_point_scalar(0.25f, shape, 2); pt_->template from_float_point_scalar(0.25f, shape, 2);
EXPECT_EQ(2, pt_->scaling_factor()); EXPECT_EQ(2, pt_->scaling_factor());
EXPECT_EQ(1, pt->data()[0]); EXPECT_EQ(1, pt->data()[0]);
EXPECT_EQ(1, pt->data()[1]); EXPECT_EQ(1, pt->data()[1]);
} }
TEST_F(PaddleTensorTest, slice_test) { TEST_F(PaddleTensorTest, slice_test) {
std::vector<size_t> shape = {2, 2}; std::vector<size_t> shape = { 2, 2 };
auto pt = _tensor_factory->template create<int64_t>(shape); auto pt = _tensor_factory->template create<int64_t>(shape);
auto ret = _tensor_factory->template create<int64_t>(); auto ret = _tensor_factory->template create<int64_t>();
auto pt_ = dynamic_cast<PaddleTensor<int64_t> *>(pt.get()); auto pt_ = dynamic_cast<PaddleTensor<int64_t>*>(pt.get());
pt_->scaling_factor() = 1; pt_->scaling_factor() = 1;
for (size_t i = 0; i < 4; ++i) { for (size_t i = 0; i < 4; ++i) {
pt->data()[i] = i; pt->data()[i] = i;
} }
pt_->slice(1, 2, ret.get()); pt_->slice(1, 2, ret.get());
auto shape_ = ret->shape(); auto shape_ = ret->shape();
EXPECT_EQ(2, shape_.size()); EXPECT_EQ(2, shape_.size());
EXPECT_EQ(1, shape_[0]); EXPECT_EQ(1, shape_[0]);
EXPECT_EQ(2, shape_[1]); EXPECT_EQ(2, shape_[1]);
EXPECT_EQ(1, ret->scaling_factor()); EXPECT_EQ(1, ret->scaling_factor());
EXPECT_EQ(2, ret->data()[0]); EXPECT_EQ(2, ret->data()[0]);
EXPECT_EQ(3, ret->data()[1]); EXPECT_EQ(3, ret->data()[1]);
} }
} // namespace aby3 } // namespace aby3
...@@ -21,14 +21,13 @@ from paddle.fluid import core ...@@ -21,14 +21,13 @@ from paddle.fluid import core
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.framework import Variable from paddle.fluid.framework import Variable
from paddle.fluid.framework import convert_np_dtype_to_dtype_ from paddle.fluid.framework import convert_np_dtype_to_dtype_
from paddle.fluid.data_feeder import check_type, check_dtype
class MpcVariable(Variable): class MpcVariable(Variable):
""" """
Extends from paddle.fluid.framework.Variable and rewrite Extends from paddle.fluid.framework.Variable and rewrite
the __init__ method where the shape is resized. the __init__ method where the shape is resized.
""" """
def __init__(self, def __init__(self,
block, block,
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
...@@ -91,22 +90,22 @@ class MpcVariable(Variable): ...@@ -91,22 +90,22 @@ class MpcVariable(Variable):
else: else:
old_dtype = self.dtype old_dtype = self.dtype
if dtype != old_dtype: if dtype != old_dtype:
raise ValueError( raise ValueError("MpcVariable {0} has been created before. "
"MpcVariable {0} has been created before. " "The previous data type is {1}; the new "
"The previous data type is {1}; the new " "data type is {2}. They are not "
"data type is {2}. They are not " "matched.".format(self.name, old_dtype,
"matched.".format(self.name, old_dtype, dtype)) dtype))
if lod_level is not None: if lod_level is not None:
if is_new_var: if is_new_var:
self.desc.set_lod_level(lod_level) self.desc.set_lod_level(lod_level)
else: else:
if lod_level != self.lod_level: if lod_level != self.lod_level:
raise ValueError( raise ValueError("MpcVariable {0} has been created before. "
"MpcVariable {0} has been created before. " "The previous lod_level is {1}; the new "
"The previous lod_level is {1}; the new " "lod_level is {2}. They are not "
"lod_level is {2}. They are not " "matched".format(self.name, self.lod_level,
"matched".format(self.name, self.lod_level, lod_level)) lod_level))
if persistable is not None: if persistable is not None:
if is_new_var: if is_new_var:
self.desc.set_persistable(persistable) self.desc.set_persistable(persistable)
...@@ -156,8 +155,7 @@ class MpcParameter(MpcVariable): ...@@ -156,8 +155,7 @@ class MpcParameter(MpcVariable):
if len(shape) == 0: if len(shape) == 0:
raise ValueError( raise ValueError(
"The dimensions of shape for MpcParameter must be greater than 0" "The dimensions of shape for MpcParameter must be greater than 0")
)
for each in shape: for each in shape:
if each < 0: if each < 0:
...@@ -175,8 +173,7 @@ class MpcParameter(MpcVariable): ...@@ -175,8 +173,7 @@ class MpcParameter(MpcVariable):
**kwargs) **kwargs)
self.trainable = kwargs.get('trainable', True) self.trainable = kwargs.get('trainable', True)
self.optimize_attr = kwargs.get('optimize_attr', self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
{'learning_rate': 1.0})
self.regularizer = kwargs.get('regularizer', None) self.regularizer = kwargs.get('regularizer', None)
...@@ -203,8 +200,8 @@ class MpcParameter(MpcVariable): ...@@ -203,8 +200,8 @@ class MpcParameter(MpcVariable):
additional_attr = ("trainable", "optimize_attr", "regularizer", additional_attr = ("trainable", "optimize_attr", "regularizer",
"gradient_clip_attr", "do_model_average") "gradient_clip_attr", "do_model_average")
for attr_name in additional_attr: for attr_name in additional_attr:
res_str += "%s: %s\n" % ( res_str += "%s: %s\n" % (attr_name,
attr_name, cpt.to_text(getattr(self, attr_name))) cpt.to_text(getattr(self, attr_name)))
else: else:
res_str = MpcVariable.to_string(self, throw_on_error, False) res_str = MpcVariable.to_string(self, throw_on_error, False)
return res_str return res_str
...@@ -245,8 +242,7 @@ def create_mpc_parameter(block, *args, **kwargs): ...@@ -245,8 +242,7 @@ def create_mpc_parameter(block, *args, **kwargs):
init_ops_len = len(init_ops) init_ops_len = len(init_ops)
if init_ops_len > 1: if init_ops_len > 1:
raise RuntimeError("mpc_param " + mpc_param.name + raise RuntimeError("mpc_param " + mpc_param.name +
" is inited by multiple init ops " + str( " is inited by multiple init ops " + str(init_ops))
init_ops))
elif init_ops_len == 1: elif init_ops_len == 1:
# TODO(Paddle 1.7): already inited, do nothing, should log a warning # TODO(Paddle 1.7): already inited, do nothing, should log a warning
pass pass
...@@ -272,7 +268,6 @@ def create_mpc_var(block, *args, **kwargs): ...@@ -272,7 +268,6 @@ def create_mpc_var(block, *args, **kwargs):
kwargs['initializer'](var, block) kwargs['initializer'](var, block)
return var return var
def is_mpc_parameter(var): def is_mpc_parameter(var):
""" """
Check whether the given variable is an instance of MpcParameter. Check whether the given variable is an instance of MpcParameter.
...@@ -282,4 +277,13 @@ def is_mpc_parameter(var): ...@@ -282,4 +277,13 @@ def is_mpc_parameter(var):
bool: True if the given `var` is an instance of MpcParameter, bool: True if the given `var` is an instance of MpcParameter,
False if not. False if not.
""" """
return isinstance(var, MpcParameter) return type(var) == MpcParameter
def check_mpc_variable_and_dtype(input,
input_name,
expected_dtype,
op_name,
extra_message=''):
check_type(input, input_name, MpcVariable, op_name, extra_message)
check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
""" """
basic mpc op layers. basic mpc op layers.
""" """
from paddle.fluid.data_feeder import check_type_and_dtype from paddle.fluid.data_feeder import check_variable_and_dtype
from ..framework import MpcVariable from ..framework import MpcVariable
from ..framework import check_mpc_variable_and_dtype
from ..mpc_layer_helper import MpcLayerHelper from ..mpc_layer_helper import MpcLayerHelper
__all__ = [ __all__ = [
...@@ -32,8 +33,8 @@ def _elementwise_op(helper): ...@@ -32,8 +33,8 @@ def _elementwise_op(helper):
assert x is not None, 'x cannot be None in {}'.format(op_type) assert x is not None, 'x cannot be None in {}'.format(op_type)
assert y is not None, 'y cannot be None in {}'.format(op_type) assert y is not None, 'y cannot be None in {}'.format(op_type)
check_type_and_dtype(x, 'x', MpcVariable, ['int64'], op_type) check_mpc_variable_and_dtype(x, 'x', ['int64'], op_type)
check_type_and_dtype(y, 'y', MpcVariable, ['int64'], op_type) check_mpc_variable_and_dtype(y, 'y', ['int64'], op_type)
axis = helper.kwargs.get('axis', -1) axis = helper.kwargs.get('axis', -1)
use_mkldnn = helper.kwargs.get('use_mkldnn', False) use_mkldnn = helper.kwargs.get('use_mkldnn', False)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
""" """
mpc math compare layers. mpc math compare layers.
""" """
from paddle.fluid.data_feeder import check_type_and_dtype
from ..framework import MpcVariable from ..framework import MpcVariable
from ..mpc_layer_helper import MpcLayerHelper from ..mpc_layer_helper import MpcLayerHelper
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
""" """
mpc math op layers. mpc math op layers.
""" """
from paddle.fluid.data_feeder import check_type_and_dtype
from ..framework import MpcVariable from ..framework import MpcVariable
from ..framework import check_mpc_variable_and_dtype
from ..mpc_layer_helper import MpcLayerHelper from ..mpc_layer_helper import MpcLayerHelper
__all__ = [ __all__ = [
...@@ -39,7 +39,7 @@ def mean(x, name=None): ...@@ -39,7 +39,7 @@ def mean(x, name=None):
Examples: todo Examples: todo
""" """
helper = MpcLayerHelper("mean", **locals()) helper = MpcLayerHelper("mean", **locals())
check_type_and_dtype(x, 'x', MpcVariable, ['int64'], 'mean') check_mpc_variable_and_dtype(x, 'x', ['int64'], 'mean')
if name is None: if name is None:
out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype) out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype)
else: else:
...@@ -64,7 +64,7 @@ def square(x, name=None): ...@@ -64,7 +64,7 @@ def square(x, name=None):
Examples: todo Examples: todo
""" """
helper = MpcLayerHelper("square", **locals()) helper = MpcLayerHelper("square", **locals())
check_type_and_dtype(x, 'x', MpcVariable, ['int64'], 'square') check_mpc_variable_and_dtype(x, 'x', ['int64'], 'square')
if name is None: if name is None:
out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype) out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype)
else: else:
...@@ -89,8 +89,7 @@ def sum(x): ...@@ -89,8 +89,7 @@ def sum(x):
Examples: todo Examples: todo
""" """
helper = MpcLayerHelper("sum", **locals()) helper = MpcLayerHelper("sum", **locals())
out = helper.create_mpc_variable_for_type_inference( out = helper.create_mpc_variable_for_type_inference(dtype=helper.input_dtype('x'))
dtype=helper.input_dtype('x'))
helper.append_op( helper.append_op(
type="mpc_sum", type="mpc_sum",
inputs={"X": x}, inputs={"X": x},
...@@ -116,18 +115,16 @@ def square_error_cost(input, label): ...@@ -116,18 +115,16 @@ def square_error_cost(input, label):
Examples: todo Examples: todo
""" """
helper = MpcLayerHelper('square_error_cost', **locals()) helper = MpcLayerHelper('square_error_cost', **locals())
minus_out = helper.create_mpc_variable_for_type_inference( minus_out = helper.create_mpc_variable_for_type_inference(dtype=input.dtype)
dtype=input.dtype)
helper.append_op( helper.append_op(
type='mpc_elementwise_sub', type='mpc_elementwise_sub',
inputs={'X': [input], inputs={'X': [input],
'Y': [label]}, 'Y': [label]},
outputs={'Out': [minus_out]}) outputs={'Out': [minus_out]})
square_out = helper.create_mpc_variable_for_type_inference( square_out = helper.create_mpc_variable_for_type_inference(dtype=input.dtype)
dtype=input.dtype)
helper.append_op( helper.append_op(
type='mpc_square', type='mpc_square',
inputs={'X': [minus_out]}, inputs={'X': [minus_out]},
outputs={'Out': [square_out]}) outputs={'Out': [square_out]})
return square_out return square_out
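Note: square_error_cost simply chains mpc_elementwise_sub and mpc_square, so the layer returns (input - label)^2 element-wise. On plaintext values that is:

    import numpy as np

    def square_error_cost_plain(pred, label):
        minus_out = pred - label          # mpc_elementwise_sub
        return minus_out * minus_out      # mpc_square

    print(square_error_cost_plain(np.array([1.0, 2.0]), np.array([0.5, 3.0])))   # [0.25 1.  ]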
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,12 +14,14 @@ ...@@ -14,12 +14,14 @@
""" """
mpc matrix op layers. mpc matrix op layers.
""" """
from paddle.fluid.data_feeder import check_type_and_dtype
from ..framework import MpcVariable from ..framework import MpcVariable
from ..framework import check_mpc_variable_and_dtype
from ..mpc_layer_helper import MpcLayerHelper from ..mpc_layer_helper import MpcLayerHelper
__all__ = ['mul', ] __all__ = [
'mul',
]
def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
...@@ -61,13 +63,13 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): ...@@ -61,13 +63,13 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
inputs = {"X": [x], "Y": [y]} inputs = {"X": [x], "Y": [y]}
attrs = { attrs = {
"x_num_col_dims": x_num_col_dims, "x_num_col_dims": x_num_col_dims,
"y_num_col_dims": y_num_col_dims "y_num_col_dims": y_num_col_dims
} }
helper = MpcLayerHelper("mul", **locals()) helper = MpcLayerHelper("mul", **locals())
check_type_and_dtype(x, 'x', MpcVariable, ['int64'], 'mul') check_mpc_variable_and_dtype(x, 'x', ['int64'], 'mul')
check_type_and_dtype(y, 'y', MpcVariable, ['int64'], 'mul') check_mpc_variable_and_dtype(y, 'y', ['int64'], 'mul')
if name is None: if name is None:
out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype) out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype)
else: else:
...@@ -75,9 +77,9 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): ...@@ -75,9 +77,9 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
name=name, dtype=x.dtype, persistable=False) name=name, dtype=x.dtype, persistable=False)
helper.append_op( helper.append_op(
type="mpc_mul", type="mpc_mul",
inputs={"X": x, inputs={"X": x,
"Y": y}, "Y": y},
attrs=attrs, attrs=attrs,
outputs={"Out": out}) outputs={"Out": out})
return out return out
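Note: x_num_col_dims and y_num_col_dims follow the usual Paddle mul convention: x is flattened into a 2-D matrix by folding its first x_num_col_dims dimensions into rows (the rest into columns), y likewise by its first y_num_col_dims dimensions, and the two matrices are multiplied. A plaintext numpy sketch of that flattening, assuming the standard fluid.layers.mul semantics rather than the MPC kernel's share handling:

    import numpy as np

    def mul_plain(x, y, x_num_col_dims=1, y_num_col_dims=1):
        xm = x.reshape(int(np.prod(x.shape[:x_num_col_dims])), -1)   # rows from leading dims
        ym = y.reshape(int(np.prod(y.shape[:y_num_col_dims])), -1)
        return xm.dot(ym)                                            # 2-D matmul result

    x = np.arange(24.0).reshape(2, 3, 4)
    y = np.arange(20.0).reshape(4, 5)
    print(mul_plain(x, y, x_num_col_dims=2).shape)   # (6, 5)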
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -17,9 +17,9 @@ mpc ml op layers. ...@@ -17,9 +17,9 @@ mpc ml op layers.
from functools import reduce from functools import reduce
from paddle.fluid.data_feeder import check_type, check_dtype from paddle.fluid.data_feeder import check_type, check_dtype
from paddle.fluid.data_feeder import check_type_and_dtype
import numpy import numpy
from ..framework import MpcVariable from ..framework import MpcVariable
from ..framework import check_mpc_variable_and_dtype
from ..mpc_layer_helper import MpcLayerHelper from ..mpc_layer_helper import MpcLayerHelper
__all__ = [ __all__ = [
...@@ -30,9 +30,6 @@ __all__ = [ ...@@ -30,9 +30,6 @@ __all__ = [
] ]
# add softmax, relu
def fc(input, def fc(input,
size, size,
num_flatten_dims=1, num_flatten_dims=1,
...@@ -186,8 +183,7 @@ def softmax(input, use_cudnn=False, name=None, axis=-1): ...@@ -186,8 +183,7 @@ def softmax(input, use_cudnn=False, name=None, axis=-1):
""" """
attrs = {"axis": axis, "use_cudnn": use_cudnn} attrs = {"axis": axis, "use_cudnn": use_cudnn}
helper = MpcLayerHelper('softmax', **locals()) helper = MpcLayerHelper('softmax', **locals())
check_type_and_dtype(input, 'input', MpcVariable, check_mpc_variable_and_dtype(input, 'input', ['int64'], 'softmax')
['float16', 'float32', 'float64'], 'softmax')
dtype = helper.input_dtype() dtype = helper.input_dtype()
mpc_softmax_out = helper.create_mpc_variable_for_type_inference(dtype) mpc_softmax_out = helper.create_mpc_variable_for_type_inference(dtype)
...@@ -226,7 +222,9 @@ def relu(input, name=None): ...@@ -226,7 +222,9 @@ def relu(input, name=None):
dtype = helper.input_dtype(input_param_name='input') dtype = helper.input_dtype(input_param_name='input')
out = helper.create_mpc_variable_for_type_inference(dtype) out = helper.create_mpc_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="mpc_relu", inputs={"X": input}, outputs={"Y": out}) type="mpc_relu",
inputs={"X": input},
outputs={"Y": out})
return out return out
......
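
In `ml.py` the same substitution is applied, and softmax's dtype check changes from float types to `int64`, matching the int64 shares that MPC variables carry; the `mpc_relu` `append_op` call is only reflowed. As a sketch of how this `MpcLayerHelper` pattern generalizes, the snippet below wires a hypothetical extra activation the same way relu is wired above — the op name `mpc_sigmoid` and the import paths are assumptions, not part of this change.

```python
# Hypothetical example of the MpcLayerHelper pattern used by relu/softmax.
# "mpc_sigmoid" is an assumed op name (it would need a matching kernel);
# the paddle_fl.mpc import paths are also assumptions.
from paddle_fl.mpc.framework import check_mpc_variable_and_dtype
from paddle_fl.mpc.mpc_layer_helper import MpcLayerHelper

def sigmoid(input, name=None):
    """Sketch of an MPC activation layer following the relu pattern."""
    helper = MpcLayerHelper('sigmoid', **locals())
    check_mpc_variable_and_dtype(input, 'input', ['int64'], 'sigmoid')
    dtype = helper.input_dtype(input_param_name='input')
    out = helper.create_mpc_variable_for_type_inference(dtype)
    helper.append_op(
        type="mpc_sigmoid",
        inputs={"X": input},
        outputs={"Y": out})
    return out
```
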
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -32,7 +32,6 @@ def monkey_patch_mpc_variable(): ...@@ -32,7 +32,6 @@ def monkey_patch_mpc_variable():
Monkey patch for operator overloading. Monkey patch for operator overloading.
:return: :return:
""" """
def unique_tmp_name(): def unique_tmp_name():
""" """
Generate temp name for variable. Generate temp name for variable.
...@@ -80,7 +79,9 @@ def monkey_patch_mpc_variable(): ...@@ -80,7 +79,9 @@ def monkey_patch_mpc_variable():
tmp_name = unique_tmp_name() tmp_name = unique_tmp_name()
return block.create_var(name=tmp_name, dtype=dtype) return block.create_var(name=tmp_name, dtype=dtype)
def _elemwise_method_creator_(method_name, op_type, reverse=False): def _elemwise_method_creator_(method_name,
op_type,
reverse=False):
""" """
Operator overloading for different method. Operator overloading for different method.
:param method_name: the name of operator which is overloaded. :param method_name: the name of operator which is overloaded.
...@@ -88,19 +89,16 @@ def monkey_patch_mpc_variable(): ...@@ -88,19 +89,16 @@ def monkey_patch_mpc_variable():
:param reverse: :param reverse:
:return: :return:
""" """
def __impl__(self, other_var): def __impl__(self, other_var):
lhs_dtype = safe_get_dtype(self) lhs_dtype = safe_get_dtype(self)
if method_name in compare_ops: if method_name in compare_ops:
if not isinstance(other_var, Variable): if not isinstance(other_var, Variable):
raise NotImplementedError( raise NotImplementedError("Unsupported data type of {} for compare operations."
"Unsupported data type of {} for compare operations." .format(other_var.name))
.format(other_var.name))
else: else:
if not isinstance(other_var, MpcVariable): if not isinstance(other_var, MpcVariable):
raise NotImplementedError( raise NotImplementedError("Unsupported data type of {}.".format(other_var.name))
"Unsupported data type of {}.".format(other_var.name))
rhs_dtype = safe_get_dtype(other_var) rhs_dtype = safe_get_dtype(other_var)
if reverse: if reverse:
...@@ -111,8 +109,7 @@ def monkey_patch_mpc_variable(): ...@@ -111,8 +109,7 @@ def monkey_patch_mpc_variable():
if method_name in compare_ops: if method_name in compare_ops:
out = create_new_tmp_var(current_block(self), dtype=rhs_dtype) out = create_new_tmp_var(current_block(self), dtype=rhs_dtype)
else: else:
out = create_new_tmp_mpc_var( out = create_new_tmp_mpc_var(current_block(self), dtype=lhs_dtype)
current_block(self), dtype=lhs_dtype)
# out = create_new_tmp_mpc_var(current_block(self), dtype=lhs_dtype) # out = create_new_tmp_mpc_var(current_block(self), dtype=lhs_dtype)
...@@ -120,9 +117,9 @@ def monkey_patch_mpc_variable(): ...@@ -120,9 +117,9 @@ def monkey_patch_mpc_variable():
if other_var.shape[0] == -1: if other_var.shape[0] == -1:
axis = 0 axis = 0
assert len(self.shape) >= len(other_var.shape), ( assert len(self.shape) >= len(other_var.shape), (
"The rank of the first argument of an binary operator cannot " "The rank of the first argument of an binary operator cannot "
"be smaller than the rank of its second argument: %s vs %s" % "be smaller than the rank of its second argument: %s vs %s" %
(len(self.shape), len(other_var.shape))) (len(self.shape), len(other_var.shape)))
current_block(self).append_op( current_block(self).append_op(
type=op_type, type=op_type,
...@@ -157,32 +154,33 @@ def monkey_patch_mpc_variable(): ...@@ -157,32 +154,33 @@ def monkey_patch_mpc_variable():
# inject methods # inject methods
for method_name, op_type, reverse in ( for method_name, op_type, reverse in (
("__add__", "mpc_elementwise_add", False), ("__add__", "mpc_elementwise_add", False),
# a+b == b+a. Do not need to reverse explicitly # a+b == b+a. Do not need to reverse explicitly
("__radd__", "mpc_elementwise_add", False), ("__radd__", "mpc_elementwise_add", False),
("__sub__", "mpc_elementwise_sub", False), ("__sub__", "mpc_elementwise_sub", False),
("__rsub__", "mpc_elementwise_sub", True), ("__rsub__", "mpc_elementwise_sub", True),
("__mul__", "mpc_elementwise_mul", False), ("__mul__", "mpc_elementwise_mul", False),
# a*b == b*a. Do not need to reverse explicitly # a*b == b*a. Do not need to reverse explicitly
("__rmul__", "mpc_elementwise_mul", False), ("__rmul__", "mpc_elementwise_mul", False),
("__div__", "mpc_elementwise_div", False), ("__div__", "mpc_elementwise_div", False),
("__truediv__", "mpc_elementwise_div", False), ("__truediv__", "mpc_elementwise_div", False),
("__rdiv__", "mpc_elementwise_div", True), ("__rdiv__", "mpc_elementwise_div", True),
("__rtruediv__", "mpc_elementwise_div", True), ("__rtruediv__", "mpc_elementwise_div", True),
("__pow__", "mpc_elementwise_pow", False), ("__pow__", "mpc_elementwise_pow", False),
("__rpow__", "mpc_elementwise_pow", True), ("__rpow__", "mpc_elementwise_pow", True),
("__floordiv__", "mpc_elementwise_floordiv", False), ("__floordiv__", "mpc_elementwise_floordiv", False),
("__mod__", "mpc_elementwise_mod", False), ("__mod__", "mpc_elementwise_mod", False),
# for logical compare # for logical compare
("__eq__", "mpc_equal", False), ("__eq__", "mpc_equal", False),
("__ne__", "mpc_not_equal", False), ("__ne__", "mpc_not_equal", False),
("__lt__", "mpc_less_than", False), ("__lt__", "mpc_less_than", False),
("__le__", "mpc_less_equal", False), ("__le__", "mpc_less_equal", False),
("__gt__", "mpc_greater_than", False), ("__gt__", "mpc_greater_than", False),
("__ge__", "mpc_greater_equal", False)): ("__ge__", "mpc_greater_equal", False)
):
# Not support computation between MpcVariable and scalar. # Not support computation between MpcVariable and scalar.
setattr(MpcVariable, method_name, setattr(MpcVariable,
method_name,
_elemwise_method_creator_(method_name, op_type, reverse) _elemwise_method_creator_(method_name, op_type, reverse)
if method_name in supported_mpc_ops else announce_not_impl) if method_name in supported_mpc_ops else announce_not_impl)
# MpcVariable.astype = astype
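
The hunks above are mostly reflowed wrapping; the substance of `monkey_patch_mpc_variable` is unchanged: a factory (`_elemwise_method_creator_`) builds each dunder method, and `setattr` installs it on `MpcVariable`, so ordinary Python operators append the corresponding `mpc_*` op (comparison operators return plain Variables). The self-contained toy below demonstrates only the injection technique itself; the `Share` class and its arithmetic are illustrative stand-ins, not paddle_fl code.

```python
# Toy analogy of the injection technique used above: a factory builds the
# dunder method and setattr installs it on the class. Share and its behavior
# are illustrative only and not part of paddle_fl.
class Share(object):
    """Stand-in for MpcVariable holding an integer share."""
    def __init__(self, value):
        self.value = value

def _method_creator(method_name, op_type, reverse=False):
    def __impl__(self, other):
        lhs, rhs = (other, self) if reverse else (self, other)
        # In paddle_fl this is where current_block(self).append_op(...) runs;
        # the toy just records which op would have been appended.
        print("append_op:", op_type)
        return Share(lhs.value + rhs.value if op_type.endswith("add")
                     else lhs.value - rhs.value)
    __impl__.__name__ = method_name
    return __impl__

for method_name, op_type, reverse in (
        ("__add__", "mpc_elementwise_add", False),
        ("__sub__", "mpc_elementwise_sub", False),
        ("__rsub__", "mpc_elementwise_sub", True)):
    setattr(Share, method_name, _method_creator(method_name, op_type, reverse))

a, b = Share(7), Share(5)
print((a - b).value)  # prints 2 via the injected __sub__
```
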
...@@ -34,7 +34,7 @@ def python_version(): ...@@ -34,7 +34,7 @@ def python_version():
max_version, mid_version, min_version = python_version() max_version, mid_version, min_version = python_version()
REQUIRED_PACKAGES = [ REQUIRED_PACKAGES = [
'six >= 1.10.0', 'protobuf >= 3.1.0', 'paddlepaddle == 1.6.3', 'six >= 1.10.0', 'protobuf >= 3.1.0', 'paddlepaddle >= 1.8.0', 'paddlepaddle-gpu >= 1.8'
] ]
if max_version < 3: if max_version < 3:
......
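
The packaging requirement moves from a pin on `paddlepaddle == 1.6.3` to `paddlepaddle >= 1.8.0`, with a GPU entry listed alongside it. A quick, illustrative way to confirm that a locally installed Paddle satisfies the new floor before building (not part of this change; dev builds that report "0.0.0" are not handled):

```python
# Illustrative check only: verify the installed PaddlePaddle meets the
# new >= 1.8.0 requirement before building or importing paddle_fl.
import paddle

major, minor = (int(part) for part in paddle.__version__.split(".")[:2])
if (major, minor) < (1, 8):
    raise RuntimeError(
        "paddlepaddle >= 1.8.0 is required, found %s" % paddle.__version__)
print("OK:", paddle.__version__)
```
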