Commit 072ccc32 authored by jhjiangcs

Improve code to support PaddlePaddle 1.8.0.

Parent 639e920b
@@ -34,8 +34,8 @@ execute_process(COMMAND ${PYTHON} -c "import paddle;print(paddle.version.full_ve
RESULT_VARIABLE ret OUTPUT_VARIABLE paddle_version OUTPUT_STRIP_TRAILING_WHITESPACE)
if (NOT ret)
-if (NOT ${paddle_version} STREQUAL "1.6.3")
-    message(FATAL_ERROR "Paddle installation of 1.6.3 is required but ${paddle_version} is found")
+if (NOT ${paddle_version} STREQUAL "1.8.0")
+    message(FATAL_ERROR "Paddle installation of 1.8.0 is required but ${paddle_version} is found")
endif()
else()
message(FATAL_ERROR "Could not get paddle version.")
......
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mpc_compare_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
@@ -23,85 +23,73 @@ using Tensor = framework::Tensor;
class MpcCompareOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of MpcCompareOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound("Input(Y) of MpcCompareOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of MpcCompareOp should not be null."));
auto dim_x = ctx->GetInputDim("X");
auto dim_y = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(dim_x.size(), dim_y.size(),
"The size of dim_y should not be greater than dim_x's.");
ctx->ShareDim("Y", /*->*/ "Out");
ctx->ShareLoD("Y", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class MpcCompareOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The first input tensor of MpcCompareOp.");
AddInput("Y", "(Tensor), The second input tensor of MpcCompareOp.");
AddOutput("Out", "(Tensor), The output tensor of MpcCompareOp.");
AddComment(R"DOC(
MPC Compare Operator.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(mpc_greater_than, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_greater_than,
ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcGreaterThanFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_greater_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_greater_equal,
ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcGreaterEqualFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_less_than, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_less_than,
ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcLessThanFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_less_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_less_equal,
ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcLessEqualFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_equal,
ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcEqualFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_not_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_not_equal,
ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcNotEqualFunctor>);
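All six comparison ops above share one kernel template and differ only in the functor type, so each REGISTER_OP_CPU_KERNEL line swaps a single template argument. A minimal standalone sketch of that dispatch pattern, not part of the commit, with hypothetical names, plain vectors, and a plaintext comparison standing in for the protocol's gt() call (which in the real op runs on secret shares):

#include <cstdint>
#include <iostream>
#include <vector>

// Plaintext stand-in for the protocol call; the real functor invokes gt() on shares.
struct GreaterThanSketch {
    void Run(const std::vector<int64_t> &x, const std::vector<int64_t> &y,
             std::vector<int64_t> *out) const {
        for (size_t i = 0; i < x.size(); ++i) {
            (*out)[i] = x[i] > y[i];
        }
    }
};

// One kernel template serves every comparison; only the functor type changes.
template <typename Functor>
void RunCompareKernel(const std::vector<int64_t> &x, const std::vector<int64_t> &y,
                      std::vector<int64_t> *out) {
    out->resize(x.size());  // mirrors out_t->mutable_data<T>(ctx.GetPlace())
    Functor().Run(x, y, out);
}

int main() {
    std::vector<int64_t> x{3, 1, 4}, y{2, 2, 4}, out;
    RunCompareKernel<GreaterThanSketch>(x, y, &out);
    for (int64_t v : out) std::cout << v << ' ';  // prints: 1 0 0
}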
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
#include <math.h>
#include <type_traits>
namespace paddle {
namespace operators {
@@ -24,58 +21,52 @@ namespace operators {
using Tensor = framework::Tensor;
struct MpcGreaterThanFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->gt(in_x_t, in_y_t, out_t);
}
};
struct MpcGreaterEqualFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->geq(in_x_t, in_y_t, out_t);
}
};
struct MpcLessThanFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->lt(in_x_t, in_y_t, out_t);
}
};
struct MpcLessEqualFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->leq(in_x_t, in_y_t, out_t);
}
};
struct MpcEqualFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->eq(in_x_t, in_y_t, out_t);
}
};
struct MpcNotEqualFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->neq(in_x_t, in_y_t, out_t);
}
};
template <typename DeviceContext, typename T, typename Functor>
class MpcCompareOpKernel : public MpcOpKernel<T> {
public:
    void ComputeImpl(const framework::ExecutionContext &ctx) const override {
        auto *in_x_t = ctx.Input<framework::LoDTensor>("X");
        auto *in_y_t = ctx.Input<framework::LoDTensor>("Y");
        auto *out_t = ctx.Output<framework::LoDTensor>("Out");
        auto out = out_t->mutable_data<T>(ctx.GetPlace());
        Functor().Run(in_x_t, in_y_t, out_t);
    }
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mpc_elementwise_add_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
@@ -22,111 +22,105 @@ using Tensor = framework::Tensor;
class MpcElementwiseAddOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of MpcElementwiseAddOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound("Input(Y) of MpcElementwiseAddOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of MpcElementwiseAddOp should not be null."));
PADDLE_ENFORCE_GE(
ctx->GetInputDim("X").size(), ctx->GetInputDim("Y").size(),
platform::errors::InvalidArgument(
"The dimensions of X should be greater than the dimensions of Y. "
"But received the dimensions of X is [%s], the dimensions of Y is [%s]",
ctx->GetInputDim("X"), ctx->GetInputDim("Y")));
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class MpcElementwiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The first input tensor of mpc elementwise add op.");
AddInput("Y", "(Tensor), The second input tensor of mpc elementwise add op.");
AddOutput("Out", "(Tensor), The output tensor of mpc elementwise add op.");
AddAttr<int>("axis",
"(int, default -1). If X.dimension != Y.dimension,"
"Y.dimension must be a subsequence of x.dimension. And axis "
"is the start dimension index "
"for broadcasting Y onto X. ")
.SetDefault(-1)
.EqualGreaterThan(-1);
AddComment(R"DOC(
MPC elementwise add Operator.
)DOC");
}
};
class MpcElementwiseAddGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override {
auto out_grad_name = framework::GradVarName("Out");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true,
"Input(Out@GRAD) should not be null.");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->ShareDim("X", /*->*/ x_grad_name);
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->ShareDim("Y", /*->*/ y_grad_name);
ctx->ShareLoD("Y", /*->*/ y_grad_name);
}
}
};
template <typename T>
-class MpcElementwiseAddOpGradMaker : public framework::SingleGradOpDescMaker {
+class MpcElementwiseAddOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
-    using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+    using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
-    std::unique_ptr<T> Apply() const override {
-        std::unique_ptr<T> retv(new T());
-        retv->SetType("mpc_elementwise_add_grad");
-        retv->SetInput("X", this->Input("X"));
-        retv->SetInput("Y", this->Input("Y"));
-        retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-        retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-        retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
-        retv->SetAttrMap(this->Attrs());
-        return retv;
-    }
+    void Apply(GradOpPtr<T> grad) const override {
+        grad->SetType("mpc_elementwise_add_grad");
+        grad->SetInput("X", this->Input("X"));
+        grad->SetInput("Y", this->Input("Y"));
+        grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+        grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+        grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
+        grad->SetAttrMap(this->Attrs());
+    }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_elementwise_add, ops::MpcElementwiseAddOp,
ops::MpcElementwiseAddOpMaker,
ops::MpcElementwiseAddOpGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_elementwise_add_grad, ops::MpcElementwiseAddGradOp);
REGISTER_OP_CPU_KERNEL(
mpc_elementwise_add,
ops::MpcElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
mpc_elementwise_add_grad,
ops::MpcElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
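The substantive 1.8.0 change in this file is the grad-op maker: SingleGradOpDescMaker's allocate-and-return Apply() becomes SingleGradOpMaker<T>'s fill-in-place Apply(GradOpPtr<T>). A standalone sketch of that API shape, using hypothetical stand-in types rather than Paddle's real ones:

#include <iostream>
#include <memory>
#include <string>

// Hypothetical mini "op description"; Paddle's OpDesc carries much more.
struct OpDescSketch {
    std::string type;
};

// Pre-1.8 shape (SingleGradOpDescMaker): the maker allocates and returns.
std::unique_ptr<OpDescSketch> ApplyOldStyle() {
    std::unique_ptr<OpDescSketch> retv(new OpDescSketch());
    retv->type = "mpc_elementwise_add_grad";
    return retv;
}

// 1.8 shape (SingleGradOpMaker<T>): the framework allocates, the maker fills in.
void ApplyNewStyle(OpDescSketch *grad) {
    grad->type = "mpc_elementwise_add_grad";
}

int main() {
    auto old_style = ApplyOldStyle();
    OpDescSketch new_style;
    ApplyNewStyle(&new_style);
    std::cout << old_style->type << " / " << new_style.type << '\n';
}

Templating the maker on T lets one definition serve more than one graph representation in 1.8, which is why the registration above instantiates MpcElementwiseAddOpGradMaker<paddle::framework::OpDesc> explicitly.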
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// This op differs from the elementwise_add op of PaddlePaddle.
// We only consider the case where the dimensions of X equal the dimensions of Y.
@@ -18,7 +18,6 @@
#pragma once
#include "mpc_op.h"
#include "paddle/fluid/platform/transform.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle {
namespace operators {
@@ -26,189 +25,187 @@ namespace operators {
using Tensor = framework::Tensor;
// paddle/fluid/operators/elementwise/elementwise_op_function.h
template <typename T, typename DeviceContext>
class RowwiseTransformIterator;
template <typename T>
class RowwiseTransformIterator<T, platform::CPUDeviceContext>
: public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t, T *, T &> {
public:
RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
RowwiseTransformIterator<T, platform::CPUDeviceContext> &operator++() {
++i_;
if (UNLIKELY(i_ == n_)) {
i_ = 0;
}
return *this;
}
RowwiseTransformIterator<T, platform::CPUDeviceContext> &operator+(int n) {
while (n-- > 0) {
++i_;
if (UNLIKELY(i_ == n_)) {
i_ = 0;
}
}
return *this;
}
bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext> &rhs) const {
return (ptr_ + i_) == &(*rhs);
}
bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext> &rhs) const {
return (ptr_ + i_) != &(*rhs);
}
const T &operator*() { return ptr_[i_]; }
private:
const T *ptr_;
int i_;
int64_t n_;
};
template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(T x, T y) { return x + y; }
};
struct GetMidDims {
inline HOSTDEVICE void operator()(const framework::DDim &x_dims,
const framework::DDim &y_dims, const int axis,
int *pre, int *n, int *post) {
*pre = 1;
*n = 1;
*post = 1;
for (int i = 1; i < axis + 1; ++i) {
(*pre) *= x_dims[i];
}
for (int i = 1; i < y_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
"Broadcast dimension mismatch.");
(*n) *= y_dims[i];
}
for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
(*post) *= x_dims[i];
}
}
};
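GetMidDims flattens a broadcast into a (pre, n, post) triple: pre rows of length n (the kernels below additionally require post == 1), and its loops start at i = 1 because dim 0 of each slice is the share dimension. A runnable sketch, not part of the commit, with plain vectors and hypothetical shapes, including the wrap-around reuse of Y that RowwiseTransformIterator's operator++ implements:

#include <cassert>
#include <iostream>
#include <vector>

// e.g. per-share dims x = [1, 2, 3], y = [1, 3], axis = 1  ->  pre = 2, n = 3, post = 1
void GetMidDimsSketch(const std::vector<int> &x_dims, const std::vector<int> &y_dims,
                      int axis, int *pre, int *n, int *post) {
    *pre = 1; *n = 1; *post = 1;
    for (int i = 1; i < axis + 1; ++i) *pre *= x_dims[i];
    for (size_t i = 1; i < y_dims.size(); ++i) {
        assert(x_dims[i + axis] == y_dims[i] && "Broadcast dimension mismatch.");
        *n *= y_dims[i];
    }
    for (size_t i = axis + y_dims.size(); i < x_dims.size(); ++i) *post *= x_dims[i];
}

int main() {
    int pre, n, post;
    GetMidDimsSketch({1, 2, 3}, {1, 3}, /*axis=*/1, &pre, &n, &post);
    std::cout << pre << ' ' << n << ' ' << post << '\n';  // 2 3 1
    // With post == 1, y is reused (cycled) for every row of x, which is
    // exactly what RowwiseTransformIterator's wrap-around increment does.
    std::vector<int> x{1, 2, 3, 4, 5, 6}, y{10, 20, 30}, out(6);
    for (int i = 0; i < pre * n; ++i) out[i] = x[i] + y[i % n];
    for (int v : out) std::cout << v << ' ';  // 11 22 33 14 25 36
}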
const size_t SHARE_NUM = 2;
template <typename DeviceContext, typename T>
class MpcElementwiseAddKernel : public MpcOpKernel<T> {
public:
    void ComputeImpl(const framework::ExecutionContext &ctx) const override {
-        auto *in_x_t = ctx.Input<Tensor>("X");
-        auto *in_y_t = ctx.Input<Tensor>("Y");
-        auto *out_t = ctx.Output<Tensor>("Out");
+        auto *in_x_t = ctx.Input<framework::LoDTensor>("X");
+        auto *in_y_t = ctx.Input<framework::LoDTensor>("Y");
+        auto *out_t = ctx.Output<framework::LoDTensor>("Out");
        int axis = ctx.Attr<int>("axis");
        auto out = out_t->mutable_data<T>(ctx.GetPlace());
        if (in_x_t->dims() == in_y_t->dims()) {
            mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(in_x_t, in_y_t, out_t);
        } else {
            Tensor in_x_t_slice;
            Tensor in_y_t_slice;
            Tensor out_t_slice;
            for (size_t i = 0; i < SHARE_NUM; ++i) {
                in_x_t_slice = in_x_t->Slice(i, i + 1);
                in_y_t_slice = in_y_t->Slice(i, i + 1);
                out_t_slice = out_t->Slice(i, i + 1);
                auto x_dims = in_x_t_slice.dims();
                auto y_dims = in_y_t_slice.dims();
                axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
                PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                               "Axis should be in range [0, x_dims)");
                int pre, n, post;
                GetMidDims get_mid_dims;
                get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
                PADDLE_ENFORCE_EQ(post, 1,
                                  "post should be equal 1, but received post is [%s]", post);
                auto x_ = in_x_t_slice.data<T>();
                auto y_ = in_y_t_slice.data<T>();
                auto out_ = out_t_slice.data<T>();
                auto nx_ = in_x_t_slice.numel();
                paddle::platform::Transform<DeviceContext> trans;
                trans(ctx.template device_context<DeviceContext>(), x_, x_ + nx_,
                      RowwiseTransformIterator<T, DeviceContext>(y_, n),
                      out_, AddFunctor<T>());
            }
        }
    }
};
template <typename DeviceContext, typename T>
class MpcElementwiseAddGradKernel : public MpcOpKernel<T> {
public:
    void ComputeImpl(const framework::ExecutionContext &ctx) const override {
        auto *in_x_t = ctx.Input<framework::LoDTensor>("X");
        auto *in_y_t = ctx.Input<framework::LoDTensor>("Y");
-        auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-        auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-        auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+        auto *dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+        auto *dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+        auto *dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
        int axis = ctx.Attr<int>("axis");
        auto dout_data = dout->data<T>();
        if (dx) {
            auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
            for (size_t i = 0; i < dout->numel(); i++) {
                dx_data[i] = dout_data[i];
            }
        }
        if (dy) {
            auto dy_data = dy->mutable_data<T>(ctx.GetPlace());
            if (in_x_t->dims().size() == in_y_t->dims().size()) {
                for (size_t i = 0; i < dout->numel(); i++) {
                    dy_data[i] = dout_data[i];
                }
            } else {
                auto x_dims = in_x_t->dims();
                auto y_dims = in_y_t->dims();
                axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
                PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                               "Axis should be in range [0, x_dims)");
                int pre, n, post;
                GetMidDims get_mid_dims;
                get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
                PADDLE_ENFORCE_EQ(post, 1,
                                  "post should be equal 1, but received post is [%s]", post);
                for (size_t i = 0; i < SHARE_NUM; ++i) {
                    int y_offset = i * n;
                    for (size_t j = 0; j < pre; ++j) {
                        for (size_t k = 0; k < n; ++k) {
                            int out_offset = i * pre * n + j * n + k;
                            if (0 == j) {
                                dy_data[k + y_offset] = dout_data[out_offset];
                            } else {
                                dy_data[k + y_offset] += dout_data[out_offset];
                            }
                        }
                    }
                }
            }
        }
    }
};
} // namespace operators
} // namespace paddle
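In the broadcast case, the backward pass reduces dout over the pre broadcast rows to form dy, independently for each share, which is what the triple loop in MpcElementwiseAddGradKernel does. A runnable sketch of that reduction on plain arrays (hypothetical sizes, post == 1 as enforced above; not part of the commit):

#include <iostream>
#include <vector>

int main() {
    const int kShareNum = 2, pre = 2, n = 3;                    // dout shaped [2, 2, 3]
    std::vector<int> dout{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
    std::vector<int> dy(kShareNum * n, 0);                      // dy shaped [2, 3]
    for (int i = 0; i < kShareNum; ++i)         // per share
        for (int j = 0; j < pre; ++j)           // sum over broadcast rows
            for (int k = 0; k < n; ++k)
                dy[i * n + k] += dout[i * pre * n + j * n + k];
    for (int v : dy) std::cout << v << ' ';  // 5 7 9 17 19 21
}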
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mpc_elementwise_sub_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class MpcElementwiseSubOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of MpcElementwiseSubOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound("Input(Y) of MpcElementwiseSubOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of MpcElementwiseSubOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("X"), ctx->GetInputDim("Y"),
platform::errors::InvalidArgument(
"The dimensions of X should be equal with the dimensions of Y. "
"But received the dimensions of X is [%s], the dimensions of Y is [%s]",
ctx->GetInputDim("X"), ctx->GetInputDim("Y")));
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class MpcElementwiseSubOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The first input tensor of mpc elementwise sub op.");
AddInput("Y", "(Tensor), The second input tensor of mpc elementwise sub op.");
AddOutput("Out", "(Tensor), The output tensor of mpc elementwise sub op.");
AddComment(R"DOC(
MPC elementwise sub Operator.
)DOC");
}
};
class MpcElementwiseSubGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
auto out_grad_name = framework::GradVarName("Out");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true,
"Input(Out@GRAD) should not be null.");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->ShareDim("X", /*->*/ x_grad_name);
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->ShareDim("Y", /*->*/ y_grad_name);
ctx->ShareLoD("Y", /*->*/ y_grad_name);
}
}
};
template <typename T>
-class MpcElementwiseSubGradMaker : public framework::SingleGradOpDescMaker {
+class MpcElementwiseSubGradMaker : public framework::SingleGradOpMaker<T> {
public:
-    using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+    using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
-    std::unique_ptr<T> Apply() const override {
-        std::unique_ptr<T> retv(new T());
-        retv->SetType("mpc_elementwise_sub_grad");
-        retv->SetInput("X", this->Input("X"));
-        retv->SetInput("Y", this->Input("Y"));
-        retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-        retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-        retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
-        retv->SetAttrMap(this->Attrs());
-        return retv;
-    }
+    void Apply(GradOpPtr<T> grad) const override {
+        grad->SetType("mpc_elementwise_sub_grad");
+        grad->SetInput("X", this->Input("X"));
+        grad->SetInput("Y", this->Input("Y"));
+        grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+        grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+        grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
+        grad->SetAttrMap(this->Attrs());
+    }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_elementwise_sub, ops::MpcElementwiseSubOp,
ops::MpcElementwiseSubOpMaker,
ops::MpcElementwiseSubGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_elementwise_sub_grad, ops::MpcElementwiseSubGradOp);
REGISTER_OP_CPU_KERNEL(
mpc_elementwise_sub,
ops::MpcElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
mpc_elementwise_sub_grad,
ops::MpcElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// This op differs from the elementwise_sub op of PaddlePaddle.
// We only consider the case where the dimensions of X equal the dimensions of Y.
#pragma once
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle {
namespace operators {
@@ -27,40 +26,39 @@ using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class MpcElementwiseSubKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *in_x_t = ctx.Input<Tensor>("X");
auto *in_y_t = ctx.Input<Tensor>("Y");
auto *out_t = ctx.Output<Tensor>("Out");
auto out = out_t->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(in_x_t, in_y_t, out_t);
}
};
template <typename DeviceContext, typename T>
class MpcElementwiseSubGradKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
VLOG(3) << "******** MpcElementwiseSubGradKernel: ";
auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto dout_data = dout->data<T>();
if (dx) {
auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
for (size_t i = 0; i < dout->numel(); i++) {
dx_data[i] = dout_data[i];
}
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->neg(dout, dy);
}
}
};
} // namespace operators
} // namespace paddle
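The sub gradient needs no protocol round for dx (a plain copy of dout's shares) and only a local negation for dy, because additive secret shares are linear: if x reconstructs from x0 + x1, then -x reconstructs from (-x0) + (-x1). A sketch with plain uint64 shares, not part of the commit (two-share additive sharing assumed for illustration; ABY3 actually uses replicated three-party shares):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // dout held as two additive shares; reconstruction is share0 + share1 (mod 2^64).
    std::vector<uint64_t> dout_share0{5, 7}, dout_share1{11, 2};
    std::vector<uint64_t> dy_share0(2), dy_share1(2);
    for (size_t i = 0; i < dout_share0.size(); ++i) {
        dy_share0[i] = -dout_share0[i];  // local negation mod 2^64, no communication
        dy_share1[i] = -dout_share1[i];
    }
    // Shares of dy reconstruct to -(dout): both sides print the same value.
    std::cout << dy_share0[0] + dy_share1[0] << " == "
              << -(dout_share0[0] + dout_share1[0]) << '\n';
}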
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Description:
#include "paddle/fluid/framework/op_registry.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_config.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle {
namespace operators {
@@ -26,59 +26,63 @@ using mpc::Aby3Config;
class MpcInitOp : public framework::OperatorBase {
public:
MpcInitOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
auto protocol_name = Attr<std::string>("protocol_name");
auto role = Attr<int>("role");
auto local_addr = Attr<std::string>("local_addr");
auto net_server_addr = Attr<std::string>("net_server_addr");
auto net_server_port = Attr<int>("net_server_port");
MpcConfig _mpc_config;
_mpc_config.set_int(Aby3Config::ROLE, role);
_mpc_config.set(Aby3Config::LOCAL_ADDR, local_addr);
_mpc_config.set(Aby3Config::NET_SERVER_ADDR, net_server_addr);
_mpc_config.set_int(Aby3Config::NET_SERVER_PORT, net_server_port);
mpc::MpcInstance::init_instance(protocol_name, _mpc_config);
}
};
class MpcInitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
    void Make() override {
        AddComment(R"DOC(
MpcInit Operator.
)DOC");
        AddAttr<std::string>("protocol_name",
                             "(string , default aby3)"
                             "protocol name")
            .SetDefault({"aby3"});
        AddAttr<int>("role", "trainer role.").SetDefault(0);
        AddAttr<std::string>("local_addr",
                             "(string, default localhost)"
                             "local addr")
            .SetDefault({"localhost"});
        AddAttr<std::string>("net_server_addr",
                             "(string, default localhost)"
                             "net server addr")
            .SetDefault({"localhost"});
        AddAttr<int>("net_server_port", "net server port, default to 6539.").SetDefault(6539);
    }
};
class MpcInitOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
mpc_init, ops::MpcInitOp,
ops::MpcInitOpMaker, ops::MpcInitOpShapeInference);
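MpcInitOp itself is just a hand-off: it copies its five attributes into an MpcConfig and boots the protocol singleton via init_instance(). A minimal standalone sketch of that key-value hand-off, not part of the commit (ConfigSketch is a hypothetical stand-in; only set()/set_int() appear in the code above, and for ABY3, a three-party protocol, role would be 0, 1, or 2):

#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-in for MpcConfig: separate string and int key-value stores.
struct ConfigSketch {
    std::map<std::string, std::string> strs;
    std::map<std::string, int> ints;
    void set(const std::string &k, const std::string &v) { strs[k] = v; }
    void set_int(const std::string &k, int v) { ints[k] = v; }
};

int main() {
    // Mirrors RunImpl above: op attributes in, one config object out to the protocol.
    ConfigSketch cfg;
    cfg.set_int("role", 0);
    cfg.set("local_addr", "localhost");
    cfg.set("net_server_addr", "localhost");
    cfg.set_int("net_server_port", 6539);
    std::cout << "party " << cfg.ints["role"] << " -> "
              << cfg.strs["net_server_addr"] << ":" << cfg.ints["net_server_port"] << '\n';
}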
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mpc_mean_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
@@ -22,78 +22,80 @@ using Tensor = framework::Tensor;
class MpcMeanOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of MpcMeanOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of MpcMeanOp should not be null."));
ctx->SetOutputDim("Out", {2, 1});
}
};
class MpcMeanOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The first input tensor of mpc mean op.");
AddOutput("Out", "(Tensor), The output tensor of mpc mean op.");
AddComment(R"DOC(
MPC mean Operator calculates the mean of all elements in X.
)DOC");
}
};
class MpcMeanOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
-    std::unordered_map<std::string, std::string>
-    GetInputOutputWithSameType() const override {
-        return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
-    }
+    std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
+            const override {
+        static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
+        return m;
+    }
};
class MpcMeanGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X"));
}
};
template <typename T>
-class MpcMeanOpGradMaker : public framework::SingleGradOpDescMaker {
+class MpcMeanOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
-    using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+    using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
-    std::unique_ptr<T> Apply() const override {
-        std::unique_ptr<T> retv(new T());
-        retv->SetType("mpc_mean_grad");
-        retv->SetInput("X", this->Input("X"));
-        retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-        retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-        return retv;
-    }
+    void Apply(GradOpPtr<T> grad) const override {
+        grad->SetType("mpc_mean_grad");
+        grad->SetInput("X", this->Input("X"));
+        grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+        grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_mean, ops::MpcMeanOp,
ops::MpcMeanOpMaker,
ops::MpcMeanOpInferVarType,
ops::MpcMeanOpGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_mean_grad, ops::MpcMeanGradOp);
REGISTER_OP_CPU_KERNEL(
mpc_mean,
ops::MpcMeanKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
mpc_mean_grad,
ops::MpcMeanGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "mpc_op.h"
#include "paddle/fluid/framework/eigen.h"
namespace paddle {
namespace operators {
......@@ -28,43 +27,40 @@ using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
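// Note on data layout: mpc tensors in this codebase carry the two secret
// shares along dim 0 (tensors are viewed as {2, ...} in the mul kernels),
// so the plaintext element count is numel() / 2 and the mean reduces to
// sum(x) * (1 / (numel() / 2)).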
template <typename DeviceContext, typename T>
class MpcMeanKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *in_x_t = ctx.Input<Tensor>("X");
auto *out_t = ctx.Output<Tensor>("Out");
out_t->mutable_data<T>(ctx.GetPlace());
double scale = 1.0 / (in_x_t->numel() / 2.0);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sum(in_x_t, out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(out_t, scale, out_t);
}
};
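// Backward of mean: dout holds one scalar per share (hence the
// numel() == 2 check below); each scalar is broadcast over the matching
// half of dx and the result is scaled by 1 / (dx->numel() / 2).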
template <typename DeviceContext, typename T>
class MpcMeanGradKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE(dout->numel() == 2, "numel of MpcMean Gradient should be 2.");
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dout_data = dout->data<T>();
if (dx) {
auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
for (size_t i = 0; i < dx->numel() / 2; ++i) {
dx_data[i] = dout_data[0];
}
for (size_t i = dx->numel() / 2; i < dx->numel(); ++i) {
dx_data[i] = dout_data[1];
}
double scale_factor = 1.0 / (dx->numel() / 2);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(dx, scale_factor, dx);
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mpc_mul_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_mul_op.h"
namespace paddle {
namespace operators {
......@@ -22,98 +22,98 @@ using Tensor = framework::Tensor;
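// Shape inference for mpc matmul: X and Y are flattened to 2-D matrices
// (dim 0, the share dim, is kept apart), the flattened X's width must match
// the flattened Y's height, and Out combines X's leading dims (including
// the share dim) with Y's trailing dims.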
class MpcMulOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of Mpc MulOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound("Input(Y) of MpcMulOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of MpcMulOp should not be null."));
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims");
int y_num_col_dims = ctx->Attrs().Get<int>("y_num_col_dims");
VLOG(3) << "mpc mul operator x.shape=" << x_dims << " y.shape=" << y_dims
<< " x_num_col_dims=" << x_num_col_dims
<< " y_num_col_dims=" << y_num_col_dims;
PADDLE_ENFORCE_NE(framework::product(y_dims), 0,
platform::errors::PreconditionNotMet(
"The Input variable Y(%s) has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.",
ctx->Inputs("Y").front()));
PADDLE_ENFORCE_GT(
x_dims.size(), x_num_col_dims,
platform::errors::InvalidArgument(
"The input tensor X's dimensions of MpcMulOp "
"should be larger than x_num_col_dims. But received X's "
"dimensions = %d, X's shape = [%s], x_num_col_dims = %d.",
x_dims.size(), x_dims, x_num_col_dims));
PADDLE_ENFORCE_GT(
y_dims.size(), y_num_col_dims,
platform::errors::InvalidArgument(
"The input tensor Y's dimensions of MpcMulOp "
"should be larger than y_num_col_dims. But received Y's "
"dimensions = %d, Y's shape = [%s], y_num_col_dims = %d.",
y_dims.size(), y_dims, y_num_col_dims));
int x_mat_width = 1;
int y_mat_height = 1;
for (size_t i = x_num_col_dims + 1; i < x_dims.size(); i++) {
x_mat_width *= x_dims[i];
}
for (size_t i = 1; i <= y_num_col_dims; i++) {
y_mat_height *= y_dims[i];
}
PADDLE_ENFORCE_EQ(
x_mat_width, y_mat_height,
platform::errors::InvalidArgument(
"After flatten the input tensor X and Y to 2-D dimensions "
"matrix X1 and Y1, the matrix X1's width must be equal with matrix "
"Y1's height. But received X's shape = [%s], X1's "
"width = %s; Y's shape = [%s], Y1's height = %s.",
x_dims, x_mat_width, y_dims, y_mat_height));
std::vector<int64_t> output_dims;
output_dims.reserve(
static_cast<size_t>(1 + x_num_col_dims + y_dims.size() - y_num_col_dims));
for (int i = 0; i <= x_num_col_dims; ++i) { // i=0, batch_size (share id)
output_dims.push_back(x_dims[i]);
}
for (int i = y_num_col_dims + 1; i < y_dims.size(); ++i) {
output_dims.push_back(y_dims[i]);
}
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class MpcMulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The first input tensor of mpc mul op.");
AddInput("Y", "(Tensor), The second input tensor of mpc mul op.");
AddOutput("Out", "(Tensor), The output tensor of mpc mul op.");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<int>(
"x_num_col_dims",
R"DOC((int, default 1), The mul_op can take tensors with more than two
dimensions as its inputs. If the input $X$ is a tensor with more
than two dimensions, $X$ will be flattened into a two-dimensional
matrix first. The flattening rule is: the first `num_col_dims`
......@@ -129,109 +129,112 @@ public:
Thus, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] =
[24, 30].
)DOC")
.SetDefault(1)
.EqualGreaterThan(1);
AddAttr<int>(
"y_num_col_dims",
R"DOC((int, default 1), The mul_op can take tensors with more than two,
dimensions as its inputs. If the input $Y$ is a tensor with more
than two dimensions, $Y$ will be flattened into a two-dimensional
matrix first. The attribute `y_num_col_dims` determines how $Y$ is
flattened. See comments of `x_num_col_dims` for more details.
)DOC")
.SetDefault(1)
.EqualGreaterThan(1);
    AddAttr<float>(
        "scale_x",
        "scale_x to be used for int8 mul input data x. scale_x has the "
        "same purpose as scale_in in OPs that support quantization. "
        "Only to be used with MKL-DNN INT8")
        .SetDefault(1.0f);
    AddAttr<std::vector<float>>(
        "scale_y",
        "scale_y to be used for int8 mul input data y. scale_y has the "
        "same purpose as scale_weights in OPs that support quantization. "
        "Only to be used with MKL-DNN INT8")
        .SetDefault({1.0f});
    AddAttr<float>("scale_out",
                   "scale_out to be used for int8 output data. "
                   "Only used with MKL-DNN INT8")
        .SetDefault(1.0f);
AddAttr<bool>(
"force_fp32_output",
"(bool, default false) Force quantize kernel output FP32, only "
"used in quantized MKL-DNN.")
.SetDefault(false);
AddComment(R"DOC(
MPC mul Operator.
)DOC");
}
};
class MpcMulOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
class MpcMulGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override {
auto out_grad_name = framework::GradVarName("Out");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true,
"Input(Out@GRAD) should not be null.");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
};
template <typename T>
class MpcMulOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad) const override {
grad->SetType("mpc_mul_grad");
grad->SetInput("X", this->Input("X"));
grad->SetInput("Y", this->Input("Y"));
grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
grad->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_mul, ops::MpcMulOp,
ops::MpcMulOpMaker,
ops::MpcMulOpInferVarType,
ops::MpcMulOpGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_mul_grad, ops::MpcMulGradOp);
REGISTER_OP_CPU_KERNEL(
mpc_mul,
ops::MpcMulKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
mpc_mul_grad,
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle {
namespace operators {
......@@ -24,185 +23,170 @@ using Tensor = framework::Tensor;
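// Forward mpc matmul: both operands are viewed as {2, rows, cols} (share
// dim first) and multiplied share-wise via the protocol's matmul primitive.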
template <typename DeviceContext, typename T>
class MpcMulKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Input<Tensor>("Y");
auto *out = ctx.Output<Tensor>("Out");
int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
auto x_dims = x->dims();
auto y_dims = y->dims();
int x_mat_width = 1;
int x_mat_height = 1;
int y_mat_width = 1;
int y_mat_height = 1;
for (size_t i = 1; i < x_dims.size(); i++) {
if (i <= x_num_col_dims) {
x_mat_width *= x_dims[i];
} else {
x_mat_height *= x_dims[i];
}
}
for (size_t i = 1; i < y_dims.size(); i++) {
if (i <= y_num_col_dims) {
y_mat_width *= y_dims[i];
} else {
y_mat_height *= y_dims[i];
}
}
Tensor x_matrix;
Tensor y_matrix;
x_matrix.ShareDataWith(*x);
y_matrix.ShareDataWith(*y);
x_matrix.Resize({2, x_mat_width, x_mat_height});
y_matrix.Resize({2, y_mat_width, y_mat_height});
out->mutable_data<T>(ctx.GetPlace());
auto out_dim = out->dims();
if (out_dim.size() > 3) {
out->Resize({2, x_mat_width, y_mat_height});
}
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(
&x_matrix, &y_matrix, out);
if (out_dim.size() > 3) {
out->Resize(out_dim);
}
}
};
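// Backward of mpc matmul: with dout viewed as {2, M, N},
//   dx = dout * y^T   and   dy = x^T * dout,
// where each transpose is materialized per share by an Eigen shuffle with
// permutation {0, 2, 1} (the share dim stays in place).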
template <typename DeviceContext, typename T>
class MpcMulGradKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto *dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto *dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
auto x_dims = x->dims();
auto y_dims = y->dims();
auto dout_dims = dout->dims();
int x_mat_width = 1;
int x_mat_height = 1;
int y_mat_width = 1;
int y_mat_height = 1;
for (size_t i = 1; i < x_dims.size(); i++) {
if (i <= x_num_col_dims) {
x_mat_width *= x_dims[i];
} else {
x_mat_height *= x_dims[i];
}
}
for (size_t i = 1; i < y_dims.size(); i++) {
if (i <= y_num_col_dims) {
y_mat_width *= y_dims[i];
} else {
y_mat_height *= y_dims[i];
}
}
Tensor x_matrix;
Tensor y_matrix;
Tensor dout_matrix;
x_matrix.ShareDataWith(*x);
y_matrix.ShareDataWith(*y);
dout_matrix.ShareDataWith(*dout);
x_matrix.Resize({2, x_mat_width, x_mat_height});
y_matrix.Resize({2, y_mat_width, y_mat_height});
dout_matrix.Resize({2, x_mat_width, y_mat_height});
if (dx != nullptr) {
dx->set_lod(x->lod());
}
if (dy != nullptr) {
dy->set_lod(y->lod());
}
Tensor x_matrix_trans;
Tensor y_matrix_trans;
x_matrix_trans.mutable_data<T>(x->dims(), ctx.GetPlace());
y_matrix_trans.mutable_data<T>(y->dims(), ctx.GetPlace());
x_matrix_trans.Resize({2, x_mat_height, x_mat_width});
y_matrix_trans.Resize({2, y_mat_height, y_mat_width});
auto& dev_ctx = ctx.template device_context<DeviceContext>();
const int Rank = 3;
Eigen::array<int, Rank> permute;
permute[0] = 0;
permute[1] = 2;
permute[2] = 1;
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
if (dx->dims().size() > 3) {
dx->Resize({2, x_mat_width, x_mat_height});
}
auto eigen_in = framework::EigenTensor<T, Rank>::From(y_matrix);
auto eigen_out = framework::EigenTensor<T, Rank>::From(y_matrix_trans);
auto* dev = dev_ctx.eigen_device();
eigen_out.device(*dev) = eigen_in.shuffle(permute);
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(
&dout_matrix, &y_matrix_trans, dx);
auto dx_dim = dx->dims();
if (dx_dim.size() > 3) {
dx->Resize(dx_dim);
}
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
if (dy->dims().size() > 3) {
dy->Resize({2, y_mat_width, y_mat_height});
}
auto eigen_in = framework::EigenTensor<T, Rank>::From(x_matrix);
auto eigen_out = framework::EigenTensor<T, Rank>::From(x_matrix_trans);
auto* dev = dev_ctx.eigen_device();
eigen_out.device(*dev) = eigen_in.shuffle(permute);
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(
&x_matrix_trans, &dout_matrix, dy);
auto dy_dim = dy->dims();
if (dy_dim.size() > 3) {
dy->Resize(dy_dim);
}
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Description:
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/privc3/circuit_context.h"
namespace paddle {
namespace operators {
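// Base class for all mpc op kernels: Compute() verifies that an mpc
// protocol has been initialized, then runs the derived ComputeImpl()
// inside the protocol's circuit context via ContextHolder::run_with_context.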
template <typename T>
class MpcOpKernel : public framework::OpKernelBase {
public:
using ELEMENT_TYPE = T;
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_instance()->mpc_protocol(),
"Mpc protocol is not yet initialized in executor");
std::shared_ptr<aby3::CircuitContext> mpc_ctx(mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context());
mpc::ContextHolder::template run_with_context<>(&ctx, mpc_ctx,
[&] { ComputeImpl(ctx); });
}
virtual void ComputeImpl(const framework::ExecutionContext& ctx) const = 0;
};
} // namespace operators
} // namespace paddle
......@@ -18,20 +18,20 @@
namespace paddle {
namespace operators {
// forward op definition
class MpcReluOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Y", in_dims);
}
};
// forward input & output definition
class MpcReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor.");
AddOutput("Y", "Output of relu_op");
......@@ -41,43 +41,46 @@ Mpc Relu Operator.
}
};
// backward op definition
class MpcReluGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim(framework::GradVarName("Y"));
ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
}
};
// backward type, input & output definition
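// Note: the grad op consumes only Y and Y@GRAD (no X input), since the
// derivative of relu can be evaluated from the forward output alone.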
template <typename T>
class MpcReluGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad) const override {
grad->SetType("mpc_relu_grad");
grad->SetInput("Y", this->Output("Y"));
grad->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
grad->SetAttrMap(this->Attrs());
grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(mpc_relu,
ops::MpcReluOp,
ops::MpcReluOpMaker,
ops::MpcReluGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_relu_grad, ops::MpcReluGradOp);
REGISTER_OP_CPU_KERNEL(mpc_relu,
ops::MpcReluKernel<CPU, int64_t>);
REGISTER_OP_CPU_KERNEL(mpc_relu_grad,
ops::MpcReluGradKernel<CPU, int64_t>);
......@@ -14,43 +14,37 @@
#pragma once
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Define forward computation
template <typename DeviceContext, typename T>
class MpcReluKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext& ctx) const override {
const Tensor* in_t = ctx.Input<Tensor>("X");
Tensor* out_t = ctx.Output<Tensor>("Y");
auto x = in_t->data<T>();
auto y = out_t->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, "Protocol is not yet created in MPC Protocol.");
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->relu(in_t,out_t);
}
};
// Define backward computation
template <typename DeviceContext, typename T>
class MpcReluGradKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext& ctx) const override {
auto* dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* y_t = ctx.Input<Tensor>("Y");
auto* dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dx = dx_t->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->relu_grad(y_t, dy_t, dx_t, 0.0);
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mpc_sgd_op.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -20,77 +20,77 @@ namespace operators {
class MpcSGDOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of MPCSGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of MPCSGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of MPCSGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of MPCSGDOp should not be null.");
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
"Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.");
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 element");
auto param_dim = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
platform::errors::InvalidArgument(
"MPCSGD Operator's input Param and Grad dimensions do not match. "
"The Param %s shape is [%s], but the Grad %s shape is [%s].",
ctx->Inputs("Param")[0], param_dim, ctx->Inputs("Grad")[0],
ctx->GetInputDim("Grad")));
}
ctx->SetOutputDim("ParamOut", param_dim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class MpcSGDOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto in_var_type = ctx->GetInputType("Param");
PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
in_var_type == framework::proto::VarType::LOD_TENSOR,
"The input Var's type should be LoDtensor or SelectedRows,"
" but the received var(%s)'s type is %s",
ctx->InputVarName("Param"), in_var_type);
ctx->SetOutputType("ParamOut", in_var_type);
//for (auto &out_var_n : framework::StaticGraphVarTypeInference::Output(ctx, "ParamOut")) {
// if (ctx->GetVarType(out_var_n) != in_var_type) {
// ctx->SetType(out_var_n, in_var_type);
//}
//}
}
};
class MpcSGDOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param", "(Tensor or SelectedRows) Input parameter");
AddInput("LearningRate", "(Tensor) Learning rate of MPCSGD");
AddInput("Grad", "(Tensor or SelectedRows) Input gradient");
AddOutput("ParamOut",
"(Tensor or SelectedRows, same with Param) "
"Output parameter, should share the same memory with Param");
AddComment(R"DOC(
MPCSGD operator
......@@ -102,13 +102,13 @@ $$param\_out = param - learning\_rate * grad$$
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
mpc_sgd, ops::MpcSGDOp, ops::MpcSGDOpMaker,
// paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
ops::MpcSGDOpInferVarType);
REGISTER_OP_CPU_KERNEL(
mpc_sgd,
ops::MpcSGDOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "mpc_op.h"
#include "paddle/fluid/framework/eigen.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle {
namespace operators {
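// mpc sgd update, param_out = param - lr * grad, is realized on shares via
// the protocol's scale (temp = lr * grad) and sub (param - temp) primitives.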
template <typename DeviceContext, typename T>
class MpcSGDOpKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
const auto *param_var = ctx.InputVar("Param");
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type()));
const auto *grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.InputNames("Grad").front(),
framework::ToTypeName(grad_var->Type()));
const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
const auto *param = ctx.Input<framework::Tensor>("Param");
const auto *grad = ctx.Input<framework::Tensor>("Grad");
auto *param_out = ctx.Output<framework::Tensor>("ParamOut");
auto sz = param_out->numel();
PADDLE_ENFORCE_EQ(param->numel(), sz);
PADDLE_ENFORCE_EQ(grad->numel(), sz);
const double *lr = learning_rate->data<double>();
param_out->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, "Protocol is not yet created in MPC Protocol.");
// update parameters
framework::Tensor temp;
temp.mutable_data<T>(param->dims(), ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(grad, lr[0], &temp);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(param, &temp, param_out);
}
};
} // namespace operators
} // namespace paddle
......@@ -117,21 +117,19 @@ MpcSigmoidCrossEntropyWithLogits Operator.
};
template <typename T>
class MpcSigmoidCrossEntropyWithLogitsGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad) const override {
grad->SetType("mpc_sigmoid_cross_entropy_with_logits_grad");
grad->SetInput("X", this->Input("X"));
grad->SetInput("Label", this->Input("Label"));
grad->SetInput("Out", this->Output("Out"));
grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad->SetAttrMap(this->Attrs());
}
};
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mpc_square_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_square_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class MpcSquareOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of MpcSquareOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of MpcSquareOp should not be null."));
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class MpcSquareOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The first input tensor of mpc square op.");
AddOutput("Out", "(Tensor), The output tensor of mpc square op.");
AddComment(R"DOC(
MPC square Operator.
)DOC");
}
};
class MpcSquareGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->ShareDim(framework::GradVarName("Out"), framework::GradVarName("X"));
ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X"));
}
};
template <typename T>
class MpcSquareGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad) const override {
grad->SetType("mpc_square_grad");
grad->SetInput("X", this->Input("X"));
grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_square, ops::MpcSquareOp,
ops::MpcSquareOpMaker,
ops::MpcSquareGradOpMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_square_grad, ops::MpcSquareGradOp);
REGISTER_OP_CPU_KERNEL(
mpc_square,
ops::MpcSquareKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
mpc_square_grad,
ops::MpcSquareGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle {
namespace operators {
......@@ -23,33 +23,31 @@ using Tensor = framework::Tensor;
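// Forward: out = x * x via the protocol's element-wise mul on shares.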
template <typename DeviceContext, typename T>
class MpcSquareKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *in_x_t = ctx.Input<Tensor>("X");
auto *out_t = ctx.Output<Tensor>("Out");
out_t->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul(in_x_t, in_x_t, out_t);
}
};
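// Backward of square: dx = 2 * x * dout, computed as scale(x, 2.0)
// followed by an element-wise mul with dout, both on shares.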
template <typename DeviceContext, typename T>
class MpcSquareGradKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *in_x_t = ctx.Input<Tensor>("X");
auto *dout_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
if (dx_t != nullptr) {
// allocate memory on device.
dx_t->mutable_data<T>(ctx.GetPlace());
// dx = dout * 2 * x
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(in_x_t, 2.0, dx_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul(dx_t, dout_t, dx_t);
}
}
};
} // namespace operators
} // namespace paddle
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "mpc_sum_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_sum_op.h"
namespace paddle {
namespace operators {
......@@ -29,131 +28,135 @@ using Tensor = framework::Tensor;
class MpcSumOp : public framework::OperatorWithKernel {
public:
    using framework::OperatorWithKernel::OperatorWithKernel;

    void InferShape(framework::InferShapeContext *ctx) const override {
        PADDLE_ENFORCE_EQ(
            ctx->HasInputs("X"), true,
            platform::errors::NotFound(
                "Input(X) of MpcSumOp should not be null."));
        PADDLE_ENFORCE_EQ(
            ctx->HasOutput("Out"), true,
            platform::errors::NotFound(
                "Output(Out) of MpcSumOp should not be null."));

        auto x_var_types = ctx->GetInputsVarType("X");
        auto x_dims = ctx->GetInputsDim("X");

        auto N = x_dims.size();
        PADDLE_ENFORCE_GT(
            N, 0,
            "ShapeError: The input tensor X's dimensions of SumOp "
            "should be larger than 0. But received X's dimensions %d, "
            "X's shape = [%s].",
            N, &x_dims);
        if (N == 1) {
            VLOG(3) << "Warning: SumOp has only one input, may waste memory";
        }

        framework::DDim in_dim({0});
        for (size_t i = 0; i < x_dims.size(); ++i) {
            auto &x_dim = x_dims[i];
            // x_dim.size() == 1 means the real dim of selected rows is [0]
            if (x_var_types[i] == framework::proto::VarType::SELECTED_ROWS &&
                x_dim.size() == 1) {
                continue;
            }
            if (framework::product(x_dim) == 0) {
                continue;
            }
            if (framework::product(in_dim) == 0) {
                in_dim = x_dim;
            } else {
                if (ctx->IsRuntime()) {
                    PADDLE_ENFORCE_EQ(
                        in_dim, x_dim,
                        "ShapeError: The input tensor X of SumOp must have same shape. "
                        "But received X[0]'s shape = [%s], X[%d]'s shape = [%s].",
                        in_dim, i, x_dim);
                } else {
                    PADDLE_ENFORCE_EQ(
                        in_dim.size(), x_dim.size(),
                        "ShapeError: The input tensor X of SumOp must have same "
                        "dimensions. But received X[0]'s dimensions = %d, X[0]'s shape = "
                        "[%s], X[%d]'s dimensions = %d, X[%d]'s shape = [%s].",
                        in_dim.size(), in_dim, i, x_dim.size(), i, x_dim);
                    // if in_dim or x_dim has -1, do not check equality
                    for (int j = 0; j < x_dim.size(); ++j) {
                        if (x_dim[j] == -1 || in_dim[j] == -1) {
                            continue;
                        }
                        PADDLE_ENFORCE_EQ(
                            in_dim[j], x_dim[j],
                            "ShapeError: The input tensor X of SumOp must have same shape "
                            "if not -1. "
                            "But received X[0]'s shape = [%s], X[%d]'s shape = [%s].",
                            in_dim, i, x_dim);
                    }
                }
            }
        }
        ctx->SetOutputDim("Out", in_dim);
        ctx->ShareLoD("X", /*->*/ "Out");
    }
};
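The rule InferShape enforces can be summarized as: inputs with zero elements are skipped, all remaining inputs must agree in shape, and at compile time dims of -1 count as unknown and are exempt from the equality check. A stand-alone sketch of that rule (hypothetical helper, not the Paddle API):

def check_sum_shapes(shapes, is_runtime):
    in_dim = None
    for i, dim in enumerate(shapes):
        if 0 in dim:  # inputs with zero elements are skipped
            continue
        if in_dim is None:
            in_dim = dim
        elif is_runtime:
            assert dim == in_dim, "X[0] %s vs X[%d] %s" % (in_dim, i, dim)
        else:
            assert len(dim) == len(in_dim)
            for a, b in zip(in_dim, dim):
                # -1 means unknown at compile time: skip the equality check
                assert a == b or a == -1 or b == -1
    return in_dim

assert check_sum_shapes([(2, 3), (2, -1)], is_runtime=False) == (2, 3)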
class MpcSumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
    void Make() override {
        AddInput("X",
                 "A Variable list. The shape and data type of the list elements "
                 "should be consistent. Variable can be a multi-dimensional Tensor "
                 "or LoDTensor, and data types can be: float32, float64, int32, "
                 "int64.")
            .AsDuplicable();
        AddOutput("Out",
                  "The sum of the input :code:`x`. Its shape and data type are "
                  "consistent with :code:`x`.");
        AddAttr<bool>("use_mkldnn",
                      "(bool, default false) Only used in mkldnn kernel")
            .SetDefault(false);
        AddComment(R"DOC(This OP is used to sum one or more Tensor or LoDTensor
of the input. If the input is LoDTensor, the output only
shares LoD information with the first input.)DOC");
    }
};
class MpcSumGradMaker : public framework::GradOpDescMakerBase {
public:
    using framework::GradOpDescMakerBase::GradOpDescMakerBase;

    std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
        auto x_grads = InputGrad("X", false);
        std::vector<std::unique_ptr<framework::OpDesc>> grad_ops;
        grad_ops.reserve(x_grads.size());
        auto og = OutputGrad("Out");
        std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
                       [&og](const std::string &x_grad) {
                           auto *grad_op = new framework::OpDesc();
                           grad_op->SetType("scale");
                           grad_op->SetInput("X", og);
                           grad_op->SetOutput("Out", {x_grad});
                           grad_op->SetAttr("scale", 1.0f);
                           return std::unique_ptr<framework::OpDesc>(grad_op);
                       });
        return grad_ops;
    }
};
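Since d(x_0 + ... + x_{n-1})/d(x_i) = 1 for every input, the grad maker emits one scale op per input, each copying dout with scale = 1.0. A one-line sketch of the resulting backward semantics (plain Python, illustrative only):

def sum_backward(dout, num_inputs):
    # every input receives an identical copy of the output gradient
    return [1.0 * dout for _ in range(num_inputs)]

assert sum_backward(5.0, 3) == [5.0, 5.0, 5.0]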
DECLARE_INPLACE_OP_INFERER(MpcSumInplace, {"X", "Out"});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// REGISTER_OP_WITHOUT_GRADIENT(mpc_sum, ops::MpcSumOp, ops::MpcSumOpMaker);
REGISTER_OPERATOR(mpc_sum, ops::MpcSumOp, ops::MpcSumOpMaker,
                  ops::MpcSumGradMaker, ops::MpcSumInplace);

REGISTER_OP_CPU_KERNEL(
    mpc_sum, ops::MpcSumKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mpc_op.h"
......@@ -23,62 +23,57 @@ using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class MpcSumKernel : public MpcOpKernel<T> {
public:
    void ComputeImpl(const framework::ExecutionContext &ctx) const override {
        auto in_vars = ctx.MultiInputVar("X");
        size_t in_num = in_vars.size();
        auto out_var = ctx.OutputVar("Out");

        bool in_place = out_var == in_vars[0];

        if (out_var->IsType<framework::LoDTensor>()) {
            auto *out = out_var->GetMutable<framework::LoDTensor>();
            auto *out_ptr = out->mutable_data<T>(ctx.GetPlace());
            if (in_num >= 1 && in_vars[0]->IsType<framework::LoDTensor>()) {
                auto &in_0_tensor = in_vars[0]->Get<framework::LoDTensor>();
                if (in_0_tensor.numel() > 0) {
                    in_place = (in_0_tensor.data<T>() == out_ptr);
                }
            }

            int start = in_place ? 1 : 0;
            if (!in_place) {
                if ((in_num >= 2) && in_vars[0]->IsType<framework::LoDTensor>() &&
                    in_vars[1]->IsType<framework::LoDTensor>()) {
                    auto &in_0 = in_vars[0]->Get<framework::LoDTensor>();
                    auto &in_1 = in_vars[1]->Get<framework::LoDTensor>();
                    if (in_0.numel() && in_1.numel()) {
                        mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(
                            &in_0, &in_1, out);
                        start = 2;
                    }
                }
                if (start != 2) {
                    auto t = framework::EigenVector<T>::Flatten(*out);
                    auto &device_ctx = ctx.template device_context<DeviceContext>();
                    t.device(*device_ctx.eigen_device()) = t.constant(static_cast<T>(0));
                }
            }

            // If in_place, just skip the first tensor.
            for (size_t i = start; i < in_num; i++) {
                if (in_vars[i]->IsType<framework::LoDTensor>()) {
                    auto &in_t = in_vars[i]->Get<framework::LoDTensor>();
                    if (in_t.numel() == 0) {
                        continue;
                    }
                    mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(
                        out, &in_t, out);
                } else {
                    PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
                }
            }
        } else {
            PADDLE_THROW("Unexpected branch, output variable type is %s",
                         framework::ToTypeName(out_var->Type()));
        }
    }
};

} // namespace operators
} // namespace paddle
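The accumulation strategy above: if the output aliases the first input, accumulation starts from index 1; otherwise the output is seeded from the first two inputs when possible (or zeroed) and the remaining tensors are added one at a time. A plain-numpy sketch of that control flow (illustrative only; np.add stands in for the MPC add):

import numpy as np

def sum_like_kernel(ins, out, in_place):
    start = 1 if in_place else 0
    if not in_place:
        if len(ins) >= 2:
            np.add(ins[0], ins[1], out=out)  # seed from the first two inputs
            start = 2
        else:
            out[...] = 0                     # nothing to seed from: zero the output
    for t in ins[start:]:
        np.add(out, t, out=out)              # accumulate the rest in place
    return out

a, b, c = (np.ones(3) for _ in range(3))
assert (sum_like_kernel([a, b, c], np.empty(3), in_place=False) == 3).all()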
This diff has been collapsed.
......@@ -18,8 +18,8 @@
#include "gtest/gtest.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/framework/tensor.h"
namespace aby3 {
......@@ -28,271 +28,274 @@ using paddle::framework::Tensor;
class PaddleTensorTest : public ::testing::Test {
public:
    std::shared_ptr<TensorAdapterFactory> _tensor_factory;
    CPUDeviceContext _cpu_ctx;

    void SetUp() {
        _tensor_factory = std::make_shared<PaddleTensorFactory>(&_cpu_ctx);
    }

    virtual ~PaddleTensorTest() noexcept {}
};
TEST_F(PaddleTensorTest, factory_test) {
    EXPECT_NO_THROW(_tensor_factory->template create<int64_t>());
    std::vector<size_t> shape = { 2, 3 };
    EXPECT_NO_THROW(_tensor_factory->template create<int64_t>(shape));
}
TEST_F(PaddleTensorTest, ctor_test) {
    Tensor t;
    // t holds no memory
    EXPECT_THROW({ PaddleTensor<int64_t> pt(&_cpu_ctx, t); },
                 ::paddle::platform::EnforceNotMet);

    t.template mutable_data<int64_t>(_cpu_ctx.GetPlace());
    EXPECT_NO_THROW({ PaddleTensor<int64_t> pt(&_cpu_ctx, t); });
}
TEST_F(PaddleTensorTest, shape_test) {
    std::vector<size_t> shape = { 2, 3 };
    auto pt = _tensor_factory->template create<int64_t>(shape);

    EXPECT_EQ(shape.size(), pt->shape().size());

    bool eq = std::equal(shape.begin(), shape.end(), pt->shape().begin());
    EXPECT_TRUE(eq);

    EXPECT_EQ(6u, pt->numel());
}
TEST_F(PaddleTensorTest, reshape_test) {
    std::vector<size_t> shape = { 2, 3 };
    auto pt = _tensor_factory->template create<int64_t>();

    pt->reshape(shape);

    EXPECT_EQ(shape.size(), pt->shape().size());

    bool eq = std::equal(shape.begin(), shape.end(), pt->shape().begin());
    EXPECT_TRUE(eq);
}
TEST_F(PaddleTensorTest, add_test) {
    std::vector<size_t> shape = { 1 };
    auto pt0 = _tensor_factory->template create<int64_t>(shape);
    auto pt1 = _tensor_factory->template create<int64_t>(shape);
    auto pt2 = _tensor_factory->template create<int64_t>(shape);
    pt0->data()[0] = 1;
    pt1->data()[0] = 2;
    pt0->add(pt1.get(), pt2.get());
    EXPECT_EQ(3, pt2->data()[0]);
}
TEST_F(PaddleTensorTest, sub_test) {
    std::vector<size_t> shape = { 1 };
    auto pt0 = _tensor_factory->template create<int64_t>(shape);
    auto pt1 = _tensor_factory->template create<int64_t>(shape);
    auto pt2 = _tensor_factory->template create<int64_t>(shape);
    pt0->data()[0] = 2;
    pt1->data()[0] = 1;
    pt0->sub(pt1.get(), pt2.get());
    EXPECT_EQ(1, pt2->data()[0]);
}
TEST_F(PaddleTensorTest, negative_test) {
    std::vector<size_t> shape = { 1 };
    auto pt0 = _tensor_factory->template create<int64_t>(shape);
    auto pt1 = _tensor_factory->template create<int64_t>(shape);
    pt0->data()[0] = 2;
    pt0->negative(pt1.get());

    EXPECT_EQ(-2, pt1->data()[0]);
}
TEST_F(PaddleTensorTest, mul_test) {
    std::vector<size_t> shape = { 1 };
    auto pt0 = _tensor_factory->template create<int64_t>(shape);
    auto pt1 = _tensor_factory->template create<int64_t>(shape);
    auto pt2 = _tensor_factory->template create<int64_t>(shape);
    pt0->data()[0] = 7;
    pt1->data()[0] = 3;
    pt0->mul(pt1.get(), pt2.get());
    EXPECT_EQ(21, pt2->data()[0]);
}
TEST_F(PaddleTensorTest, div_test) {
    std::vector<size_t> shape = { 1 };
    auto pt0 = _tensor_factory->template create<int64_t>(shape);
    auto pt1 = _tensor_factory->template create<int64_t>(shape);
    auto pt2 = _tensor_factory->template create<int64_t>(shape);
    pt0->data()[0] = 7;
    pt1->data()[0] = 3;
    pt0->div(pt1.get(), pt2.get());
    EXPECT_EQ(2, pt2->data()[0]);
}
TEST_F(PaddleTensorTest, matmul_test) {
    std::vector<size_t> shape0 = { 2, 3 };
    std::vector<size_t> shape1 = { 3, 2 };
    std::vector<size_t> shape2 = { 2, 2 };
    auto pt0 = _tensor_factory->template create<int64_t>(shape0);
    auto pt1 = _tensor_factory->template create<int64_t>(shape1);
    auto pt2 = _tensor_factory->template create<int64_t>(shape2);
    for (size_t i = 0; i < 6; ++i) {
        pt0->data()[i] = i;
        pt1->data()[i] = i;
    }
    pt0->mat_mul(pt1.get(), pt2.get());
    // | 0 1 2 |   | 0 1 |   | 10 13 |
    // | 3 4 5 | x | 2 3 | = | 28 40 |
    //             | 4 5 |
    std::vector<int64_t> res = { 10, 13, 28, 40 };
    bool eq = std::equal(res.begin(), res.end(), pt2->data());
    EXPECT_TRUE(eq);
}
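A quick numpy cross-check of the product in the comment above (illustrative only):

import numpy as np

lhs = np.arange(6).reshape(2, 3)
rhs = np.arange(6).reshape(3, 2)
assert (lhs @ rhs).flatten().tolist() == [10, 13, 28, 40]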
TEST_F(PaddleTensorTest, xor_test) {
    std::vector<size_t> shape = { 1 };
    auto pt0 = _tensor_factory->template create<int64_t>(shape);
    auto pt1 = _tensor_factory->template create<int64_t>(shape);
    auto pt2 = _tensor_factory->template create<int64_t>(shape);
    pt0->data()[0] = 3;
    pt1->data()[0] = 7;
    pt0->bitwise_xor(pt1.get(), pt2.get());
    EXPECT_EQ(4, pt2->data()[0]);
}
TEST_F(PaddleTensorTest, and_test) {
    std::vector<size_t> shape = { 1 };
    auto pt0 = _tensor_factory->template create<int64_t>(shape);
    auto pt1 = _tensor_factory->template create<int64_t>(shape);
    auto pt2 = _tensor_factory->template create<int64_t>(shape);
    pt0->data()[0] = 3;
    pt1->data()[0] = 7;
    pt0->bitwise_and(pt1.get(), pt2.get());
    EXPECT_EQ(3, pt2->data()[0]);
}
TEST_F(PaddleTensorTest, or_test) {
    std::vector<size_t> shape = { 1 };
    auto pt0 = _tensor_factory->template create<int64_t>(shape);
    auto pt1 = _tensor_factory->template create<int64_t>(shape);
    auto pt2 = _tensor_factory->template create<int64_t>(shape);
    pt0->data()[0] = 3;
    pt1->data()[0] = 7;
    pt0->bitwise_or(pt1.get(), pt2.get());
    EXPECT_EQ(7, pt2->data()[0]);
}
TEST_F(PaddleTensorTest, not_test) {
    std::vector<size_t> shape = { 1 };
    auto pt0 = _tensor_factory->template create<int64_t>(shape);
    auto pt1 = _tensor_factory->template create<int64_t>(shape);
    pt0->data()[0] = 0;
    pt0->bitwise_not(pt1.get());

    EXPECT_EQ(-1, pt1->data()[0]);
}
TEST_F(PaddleTensorTest, lshift_test) {
    std::vector<size_t> shape = { 1 };
    auto pt0 = _tensor_factory->template create<int64_t>(shape);
    auto pt1 = _tensor_factory->template create<int64_t>(shape);
    pt0->data()[0] = 2;
    pt0->lshift(1, pt1.get());

    EXPECT_EQ(4, pt1->data()[0]);
}
TEST_F(PaddleTensorTest, rshift_test) {
    std::vector<size_t> shape = { 1 };
    auto pt0 = _tensor_factory->template create<int64_t>(shape);
    auto pt1 = _tensor_factory->template create<int64_t>(shape);
    pt0->data()[0] = 2;
    pt0->rshift(1, pt1.get());

    EXPECT_EQ(1, pt1->data()[0]);
}
TEST_F(PaddleTensorTest, logical_rshift_test) {
    std::vector<size_t> shape = { 1 };
    auto pt0 = _tensor_factory->template create<int64_t>(shape);
    auto pt1 = _tensor_factory->template create<int64_t>(shape);
    pt0->data()[0] = -1;
    pt0->logical_rshift(1, pt1.get());

    EXPECT_EQ(-1ull >> 1, pt1->data()[0]);
}
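Unlike the arithmetic shift in rshift_test, a logical right shift fills the vacated high bits with zeros, so -1 maps to the largest value the 64-bit word can hold after one shift. A Python sketch of the 64-bit case (illustrative only):

def logical_rshift64(v, n):
    # mask to the unsigned 64-bit representation before shifting
    return (v & 0xFFFFFFFFFFFFFFFF) >> n

assert logical_rshift64(-1, 1) == (1 << 63) - 1  # i.e. -1ull >> 1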
TEST_F(PaddleTensorTest, scale_test) {
    auto pt = _tensor_factory->template create<int64_t>();

    auto pt_ = dynamic_cast<PaddleTensor<int64_t> *>(pt.get());

    pt_->scaling_factor() = 1;

    Tensor t;

    int dim[1] = { 1 };
    paddle::framework::DDim ddim(dim, 1);
    t.template mutable_data<float>(ddim, _cpu_ctx.GetPlace());

    t.template data<float>()[0] = 0.25f;

    pt_->template from_float_point_type<float>(t, 2);

    EXPECT_EQ(2, pt_->scaling_factor());
    EXPECT_EQ(1, pt->data()[0]);
}
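The test exercises the fixed-point encoding: from_float_point_type stores a float as an integer scaled by 2^scaling_factor (the example is exact, so rounding versus truncation is not observable here), which is why 0.25 with a factor of 2 is stored as 1. A sketch of the encoding (illustrative only, rounding assumed):

def to_fixed_point(x, scaling_factor):
    return int(round(x * (1 << scaling_factor)))

assert to_fixed_point(0.25, 2) == 1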
TEST_F(PaddleTensorTest, scalar_test) {
    auto pt = _tensor_factory->template create<int64_t>();

    auto pt_ = dynamic_cast<PaddleTensor<int64_t> *>(pt.get());

    pt_->scaling_factor() = 1;

    std::vector<size_t> shape = { 2 };
    pt_->template from_float_point_scalar(0.25f, shape, 2);

    EXPECT_EQ(2, pt_->scaling_factor());
    EXPECT_EQ(1, pt->data()[0]);
    EXPECT_EQ(1, pt->data()[1]);
}
TEST_F(PaddleTensorTest, slice_test) {
    std::vector<size_t> shape = { 2, 2 };
    auto pt = _tensor_factory->template create<int64_t>(shape);
    auto ret = _tensor_factory->template create<int64_t>();

    auto pt_ = dynamic_cast<PaddleTensor<int64_t> *>(pt.get());
    pt_->scaling_factor() = 1;

    // fill the 2x2 tensor with 0..3 (indexing by i, so the slice below sees 2 and 3)
    for (size_t i = 0; i < 4; ++i) {
        pt->data()[i] = i;
    }

    pt_->slice(1, 2, ret.get());

    auto shape_ = ret->shape();

    EXPECT_EQ(2, shape_.size());
    EXPECT_EQ(1, shape_[0]);
    EXPECT_EQ(2, shape_[1]);

    EXPECT_EQ(1, ret->scaling_factor());

    EXPECT_EQ(2, ret->data()[0]);
    EXPECT_EQ(3, ret->data()[1]);
}
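A numpy equivalent of the slice above (illustrative only): slice(1, 2, ...) keeps rows [1, 2) of the 2x2 tensor, yielding shape (1, 2) with values 2 and 3.

import numpy as np

arr = np.arange(4).reshape(2, 2)
assert arr[1:2, :].tolist() == [[2, 3]]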
} // namespace aby3
......@@ -21,14 +21,13 @@ from paddle.fluid import core
from paddle.fluid import unique_name
from paddle.fluid.framework import Variable
from paddle.fluid.framework import convert_np_dtype_to_dtype_
from paddle.fluid.data_feeder import check_type, check_dtype
class MpcVariable(Variable):
"""
Extends from paddle.fluid.framework.Variable and rewrite
the __init__ method where the shape is resized.
"""
def __init__(self,
block,
type=core.VarDesc.VarType.LOD_TENSOR,
......@@ -91,22 +90,22 @@ class MpcVariable(Variable):
else:
old_dtype = self.dtype
if dtype != old_dtype:
raise ValueError("MpcVariable {0} has been created before. "
                 "The previous data type is {1}; the new "
                 "data type is {2}. They are not "
                 "matched.".format(self.name, old_dtype, dtype))
if lod_level is not None:
if is_new_var:
self.desc.set_lod_level(lod_level)
else:
if lod_level != self.lod_level:
raise ValueError("MpcVariable {0} has been created before. "
                 "The previous lod_level is {1}; the new "
                 "lod_level is {2}. They are not "
                 "matched".format(self.name, self.lod_level, lod_level))
if persistable is not None:
if is_new_var:
self.desc.set_persistable(persistable)
......@@ -156,8 +155,7 @@ class MpcParameter(MpcVariable):
if len(shape) == 0:
raise ValueError(
    "The dimensions of shape for MpcParameter must be greater than 0")
for each in shape:
if each < 0:
......@@ -175,8 +173,7 @@ class MpcParameter(MpcVariable):
**kwargs)
self.trainable = kwargs.get('trainable', True)
self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
self.regularizer = kwargs.get('regularizer', None)
......@@ -203,8 +200,8 @@ class MpcParameter(MpcVariable):
additional_attr = ("trainable", "optimize_attr", "regularizer",
"gradient_clip_attr", "do_model_average")
for attr_name in additional_attr:
res_str += "%s: %s\n" % (
attr_name, cpt.to_text(getattr(self, attr_name)))
res_str += "%s: %s\n" % (attr_name,
cpt.to_text(getattr(self, attr_name)))
else:
res_str = MpcVariable.to_string(self, throw_on_error, False)
return res_str
......@@ -245,8 +242,7 @@ def create_mpc_parameter(block, *args, **kwargs):
init_ops_len = len(init_ops)
if init_ops_len > 1:
raise RuntimeError("mpc_param " + mpc_param.name +
" is inited by multiple init ops " + str(
init_ops))
" is inited by multiple init ops " + str(init_ops))
elif init_ops_len == 1:
# TODO(Paddle 1.7): already inited, do nothing, should log a warning
pass
......@@ -272,7 +268,6 @@ def create_mpc_var(block, *args, **kwargs):
kwargs['initializer'](var, block)
return var
def is_mpc_parameter(var):
"""
Check whether the given variable is an instance of MpcParameter.
......@@ -282,4 +277,13 @@ def is_mpc_parameter(var):
bool: True if the given `var` is an instance of Parameter,
False if not.
"""
return type(var) == MpcParameter
def check_mpc_variable_and_dtype(input, input_name, expected_dtype,
                                 op_name, extra_message=''):
check_type(input, input_name, MpcVariable, op_name, extra_message)
check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message)
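check_mpc_variable_and_dtype composes the two stock data_feeder checks: first that the input is an MpcVariable, then that its dtype is among the expected ones; layer code later in this commit calls it as check_mpc_variable_and_dtype(x, 'x', ['int64'], 'mul'). A stand-alone sketch of the same two-step validation (toy types, not the Paddle API):

class FakeMpcVar(object):  # stands in for MpcVariable
    dtype = 'int64'

def check_sketch(var, name, expected_dtypes, op_name):
    if not isinstance(var, FakeMpcVar):
        raise TypeError("%s: %s must be an MpcVariable" % (op_name, name))
    if var.dtype not in expected_dtypes:
        raise TypeError("%s: dtype of %s must be in %s" % (op_name, name, expected_dtypes))

check_sketch(FakeMpcVar(), 'x', ['int64'], 'mul')  # passes silently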
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
......@@ -14,9 +14,10 @@
"""
basic mpc op layers.
"""
from paddle.fluid.data_feeder import check_variable_and_dtype
from ..framework import MpcVariable
from ..framework import check_mpc_variable_and_dtype
from ..mpc_layer_helper import MpcLayerHelper
__all__ = [
......@@ -32,8 +33,8 @@ def _elementwise_op(helper):
assert x is not None, 'x cannot be None in {}'.format(op_type)
assert y is not None, 'y cannot be None in {}'.format(op_type)
check_mpc_variable_and_dtype(x, 'x', ['int64'], op_type)
check_mpc_variable_and_dtype(y, 'y', ['int64'], op_type)
axis = helper.kwargs.get('axis', -1)
use_mkldnn = helper.kwargs.get('use_mkldnn', False)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,7 +14,6 @@
"""
mpc math compare layers.
"""
from ..framework import MpcVariable
from ..mpc_layer_helper import MpcLayerHelper
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,9 +14,9 @@
"""
mpc math op layers.
"""
from ..framework import MpcVariable
from ..framework import check_mpc_variable_and_dtype
from ..mpc_layer_helper import MpcLayerHelper
__all__ = [
......@@ -39,7 +39,7 @@ def mean(x, name=None):
Examples: todo
"""
helper = MpcLayerHelper("mean", **locals())
check_mpc_variable_and_dtype(x, 'x', ['int64'], 'mean')
if name is None:
out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype)
else:
......@@ -64,7 +64,7 @@ def square(x, name=None):
Examples: todo
"""
helper = MpcLayerHelper("square", **locals())
check_mpc_variable_and_dtype(x, 'x', ['int64'], 'square')
if name is None:
out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype)
else:
......@@ -89,8 +89,7 @@ def sum(x):
Examples: todo
"""
helper = MpcLayerHelper("sum", **locals())
out = helper.create_mpc_variable_for_type_inference(dtype=helper.input_dtype('x'))
helper.append_op(
type="mpc_sum",
inputs={"X": x},
......@@ -116,18 +115,16 @@ def square_error_cost(input, label):
Examples: todo
"""
helper = MpcLayerHelper('square_error_cost', **locals())
minus_out = helper.create_mpc_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='mpc_elementwise_sub',
inputs={'X': [input],
'Y': [label]},
outputs={'Out': [minus_out]})
square_out = helper.create_mpc_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='mpc_square',
inputs={'X': [minus_out]},
outputs={'Out': [square_out]})
return square_out
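square_error_cost therefore composes mpc_elementwise_sub and mpc_square, i.e. (input - label) ** 2 element-wise. A numpy sketch of the plaintext analogue (illustrative only):

import numpy as np

pred = np.array([1.0, 2.0])
label = np.array([0.5, 2.5])
assert np.allclose((pred - label) ** 2, [0.25, 0.25])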
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,12 +14,14 @@
"""
mpc matrix op layers.
"""
from ..framework import MpcVariable
from ..framework import check_mpc_variable_and_dtype
from ..mpc_layer_helper import MpcLayerHelper
__all__ = [
    'mul',
]
def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
......@@ -61,13 +63,13 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
inputs = {"X": [x], "Y": [y]}
attrs = {
"x_num_col_dims": x_num_col_dims,
"x_num_col_dims": x_num_col_dims,
"y_num_col_dims": y_num_col_dims
}
helper = MpcLayerHelper("mul", **locals())
check_mpc_variable_and_dtype(x, 'x', ['int64'], 'mul')
check_mpc_variable_and_dtype(y, 'y', ['int64'], 'mul')
if name is None:
out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype)
else:
......@@ -75,9 +77,9 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
name=name, dtype=x.dtype, persistable=False)
helper.append_op(
type="mpc_mul",
inputs={"X": x,
"Y": y},
attrs=attrs,
type="mpc_mul",
inputs={"X": x,
"Y": y},
attrs=attrs,
outputs={"Out": out})
return out
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -17,9 +17,9 @@ mpc ml op layers.
from functools import reduce
from paddle.fluid.data_feeder import check_type, check_dtype
import numpy
from ..framework import MpcVariable
from ..framework import check_mpc_variable_and_dtype
from ..mpc_layer_helper import MpcLayerHelper
__all__ = [
......@@ -30,9 +30,6 @@ __all__ = [
]
def fc(input,
size,
num_flatten_dims=1,
......@@ -186,8 +183,7 @@ def softmax(input, use_cudnn=False, name=None, axis=-1):
"""
attrs = {"axis": axis, "use_cudnn": use_cudnn}
helper = MpcLayerHelper('softmax', **locals())
check_mpc_variable_and_dtype(input, 'input', ['int64'], 'softmax')
dtype = helper.input_dtype()
mpc_softmax_out = helper.create_mpc_variable_for_type_inference(dtype)
......@@ -226,7 +222,9 @@ def relu(input, name=None):
dtype = helper.input_dtype(input_param_name='input')
out = helper.create_mpc_variable_for_type_inference(dtype)
helper.append_op(
type="mpc_relu", inputs={"X": input}, outputs={"Y": out})
type="mpc_relu",
inputs={"X": input},
outputs={"Y": out})
return out
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -32,7 +32,6 @@ def monkey_patch_mpc_variable():
Monkey patch for operator overloading.
:return:
"""
def unique_tmp_name():
"""
Generate temp name for variable.
......@@ -80,7 +79,9 @@ def monkey_patch_mpc_variable():
tmp_name = unique_tmp_name()
return block.create_var(name=tmp_name, dtype=dtype)
def _elemwise_method_creator_(method_name, op_type, reverse=False):
"""
Operator overloading for different method.
:param method_name: the name of operator which is overloaded.
......@@ -88,19 +89,16 @@ def monkey_patch_mpc_variable():
:param reverse:
:return:
"""
def __impl__(self, other_var):
lhs_dtype = safe_get_dtype(self)
if method_name in compare_ops:
if not isinstance(other_var, Variable):
raise NotImplementedError(
    "Unsupported data type of {} for compare operations.".format(other_var.name))
else:
if not isinstance(other_var, MpcVariable):
raise NotImplementedError(
    "Unsupported data type of {}.".format(other_var.name))
rhs_dtype = safe_get_dtype(other_var)
if reverse:
......@@ -111,8 +109,7 @@ def monkey_patch_mpc_variable():
if method_name in compare_ops:
out = create_new_tmp_var(current_block(self), dtype=rhs_dtype)
else:
out = create_new_tmp_mpc_var(current_block(self), dtype=lhs_dtype)
# out = create_new_tmp_mpc_var(current_block(self), dtype=lhs_dtype)
......@@ -120,9 +117,9 @@ def monkey_patch_mpc_variable():
if other_var.shape[0] == -1:
axis = 0
assert len(self.shape) >= len(other_var.shape), (
"The rank of the first argument of an binary operator cannot "
"be smaller than the rank of its second argument: %s vs %s" %
(len(self.shape), len(other_var.shape)))
"The rank of the first argument of an binary operator cannot "
"be smaller than the rank of its second argument: %s vs %s" %
(len(self.shape), len(other_var.shape)))
current_block(self).append_op(
type=op_type,
......@@ -157,32 +154,33 @@ def monkey_patch_mpc_variable():
# inject methods
for method_name, op_type, reverse in (
("__add__", "mpc_elementwise_add", False),
("__add__", "mpc_elementwise_add", False),
# a+b == b+a. Do not need to reverse explicitly
("__radd__", "mpc_elementwise_add", False),
("__sub__", "mpc_elementwise_sub", False),
("__rsub__", "mpc_elementwise_sub", True),
("__mul__", "mpc_elementwise_mul", False),
("__radd__", "mpc_elementwise_add", False),
("__sub__", "mpc_elementwise_sub", False),
("__rsub__", "mpc_elementwise_sub", True),
("__mul__", "mpc_elementwise_mul", False),
# a*b == b*a. Do not need to reverse explicitly
("__rmul__", "mpc_elementwise_mul", False),
("__div__", "mpc_elementwise_div", False),
("__truediv__", "mpc_elementwise_div", False),
("__rdiv__", "mpc_elementwise_div", True),
("__rtruediv__", "mpc_elementwise_div", True),
("__pow__", "mpc_elementwise_pow", False),
("__rpow__", "mpc_elementwise_pow", True),
("__floordiv__", "mpc_elementwise_floordiv", False),
("__mod__", "mpc_elementwise_mod", False),
("__rmul__", "mpc_elementwise_mul", False),
("__div__", "mpc_elementwise_div", False),
("__truediv__", "mpc_elementwise_div", False),
("__rdiv__", "mpc_elementwise_div", True),
("__rtruediv__", "mpc_elementwise_div", True),
("__pow__", "mpc_elementwise_pow", False),
("__rpow__", "mpc_elementwise_pow", True),
("__floordiv__", "mpc_elementwise_floordiv", False),
("__mod__", "mpc_elementwise_mod", False),
# for logical compare
("__eq__", "mpc_equal", False),
("__ne__", "mpc_not_equal", False),
("__lt__", "mpc_less_than", False),
("__le__", "mpc_less_equal", False),
("__gt__", "mpc_greater_than", False),
("__ge__", "mpc_greater_equal", False)):
("__eq__", "mpc_equal", False),
("__ne__", "mpc_not_equal", False),
("__lt__", "mpc_less_than", False),
("__le__", "mpc_less_equal", False),
("__gt__", "mpc_greater_than", False),
("__ge__", "mpc_greater_equal", False)
):
# Not support computation between MpcVariable and scalar.
setattr(MpcVariable, method_name,
setattr(MpcVariable,
method_name,
_elemwise_method_creator_(method_name, op_type, reverse)
if method_name in supported_mpc_ops else announce_not_impl)
# MpcVariable.astype = astype
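The injection above is ordinary setattr-based operator overloading: one factory closure per dunder method, bound onto the class in a loop. A stand-alone toy sketch of the pattern (no Paddle types involved):

def _method_creator(op):
    def __impl__(self, other):
        return type(self)(op(self.v, other.v))
    return __impl__

class Toy(object):
    def __init__(self, v):
        self.v = v

for name, op in (("__add__", lambda a, b: a + b),
                 ("__mul__", lambda a, b: a * b)):
    setattr(Toy, name, _method_creator(op))

assert (Toy(2) + Toy(3)).v == 5
assert (Toy(2) * Toy(3)).v == 6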
......@@ -34,7 +34,7 @@ def python_version():
max_version, mid_version, min_version = python_version()
REQUIRED_PACKAGES = [
'six >= 1.10.0', 'protobuf >= 3.1.0', 'paddlepaddle == 1.8.0', 'paddlepaddle-gpu >= 1.8'
]
if max_version < 3:
......