Unverified commit a1a9bf6b, authored by Qinghe JING, committed by GitHub

Merge pull request #70 from jhjiangcs/smc-611

Improve code to support PaddlePaddle 1.8.0.
@@ -34,8 +34,8 @@ execute_process(COMMAND ${PYTHON} -c "import paddle;print(paddle.version.full_ve
     RESULT_VARIABLE ret OUTPUT_VARIABLE paddle_version OUTPUT_STRIP_TRAILING_WHITESPACE)
 if (NOT ret)
-    if (NOT ${paddle_version} STREQUAL "1.6.3")
-        message(FATAL_ERROR "Paddle installation of 1.6.3 is required but ${paddle_version} is found")
+    if (NOT ${paddle_version} STREQUAL "1.8.0")
+        message(FATAL_ERROR "Paddle installation of 1.8.0 is required but ${paddle_version} is found")
     endif()
 else()
     message(FATAL_ERROR "Could not get paddle version.")
......
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
......
-// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */

-#include "mpc_compare_op.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "mpc_compare_op.h"
namespace paddle {
namespace operators {

@@ -25,16 +25,16 @@ class MpcCompareOp : public framework::OperatorWithKernel {
public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::NotFound("Input(X) of MpcCompareOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Y"), true,
        platform::errors::NotFound("Input(Y) of MpcCompareOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Out"), true,
        platform::errors::NotFound("Output(Out) of MpcCompareOp should not be null."));

    auto dim_x = ctx->GetInputDim("X");
    auto dim_y = ctx->GetInputDim("Y");

@@ -45,11 +45,12 @@ public:
    ctx->ShareLoD("Y", /*->*/ "Out");
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};

class MpcCompareOpMaker : public framework::OpProtoAndCheckerMaker {

@@ -68,40 +69,27 @@ MPC Compare Operator.
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_WITHOUT_GRADIENT(mpc_greater_than, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_greater_than,
    ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcGreaterThanFunctor>);

REGISTER_OP_WITHOUT_GRADIENT(mpc_greater_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_greater_equal,
    ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcGreaterEqualFunctor>);

REGISTER_OP_WITHOUT_GRADIENT(mpc_less_than, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_less_than,
    ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcLessThanFunctor>);

REGISTER_OP_WITHOUT_GRADIENT(mpc_less_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_less_equal,
    ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcLessEqualFunctor>);

REGISTER_OP_WITHOUT_GRADIENT(mpc_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_equal,
    ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcEqualFunctor>);

REGISTER_OP_WITHOUT_GRADIENT(mpc_not_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_not_equal,
    ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcNotEqualFunctor>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "mpc_op.h"
-#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
-#include <math.h>
-#include <type_traits>

namespace paddle {
namespace operators {

@@ -25,50 +22,44 @@ using Tensor = framework::Tensor;
struct MpcGreaterThanFunctor {
  void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
    mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->gt(in_x_t, in_y_t, out_t);
  }
};

struct MpcGreaterEqualFunctor {
  void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
    mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->geq(in_x_t, in_y_t, out_t);
  }
};

struct MpcLessThanFunctor {
  void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
    mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->lt(in_x_t, in_y_t, out_t);
  }
};

struct MpcLessEqualFunctor {
  void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
    mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->leq(in_x_t, in_y_t, out_t);
  }
};

struct MpcEqualFunctor {
  void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
    mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->eq(in_x_t, in_y_t, out_t);
  }
};

struct MpcNotEqualFunctor {
  void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
    mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->neq(in_x_t, in_y_t, out_t);
  }
};

template <typename DeviceContext, typename T, typename Functor>
class MpcCompareOpKernel : public MpcOpKernel<T> {
public:
  void ComputeImpl(const framework::ExecutionContext &ctx) const override {
    auto *in_x_t = ctx.Input<framework::LoDTensor>("X");
    auto *in_y_t = ctx.Input<framework::LoDTensor>("Y");
    auto *out_t = ctx.Output<framework::LoDTensor>("Out");
......
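Note: every comparison above shares the one MpcCompareOpKernel template; only the functor changes which protocol primitive (gt, geq, lt, leq, eq, neq) is invoked. A minimal self-contained sketch of that dispatch pattern, with toy stand-ins for the Paddle tensor and the MPC call (names here are illustrative, not from this repo):

    #include <iostream>

    // Toy stand-in for framework::Tensor; the real functor forwards to
    // mpc_operators()->gt(...) and friends instead of computing in plaintext.
    struct Tensor {
        long long value;
    };

    struct GtFunctor {
        // Plaintext stand-in for the gt protocol primitive.
        void Run(const Tensor *x, const Tensor *y, Tensor *out) {
            out->value = (x->value > y->value) ? 1 : 0;
        }
    };

    // One kernel template serves every comparison; the functor picks the primitive.
    template <typename Functor>
    void CompareKernel(const Tensor *x, const Tensor *y, Tensor *out) {
        Functor().Run(x, y, out);
    }

    int main() {
        Tensor x{3}, y{2}, out{0};
        CompareKernel<GtFunctor>(&x, &y, &out);
        std::cout << out.value << "\n";  // prints 1
    }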
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

-#include "mpc_elementwise_add_op.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "mpc_elementwise_add_op.h"

namespace paddle {
namespace operators {

@@ -24,38 +24,33 @@ class MpcElementwiseAddOp : public framework::OperatorWithKernel {
public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::NotFound("Input(X) of MpcElementwiseAddOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Y"), true,
        platform::errors::NotFound("Input(Y) of MpcElementwiseAddOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Out"), true,
        platform::errors::NotFound("Output(Out) of MpcElementwiseAddOp should not be null."));
    PADDLE_ENFORCE_GE(
        ctx->GetInputDim("X").size(), ctx->GetInputDim("Y").size(),
        platform::errors::InvalidArgument(
            "The dimensions of X should be greater than or equal to the dimensions of Y. "
            "But received the dimensions of X is [%s], the dimensions of Y is [%s]",
            ctx->GetInputDim("X"), ctx->GetInputDim("Y")));

    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};

class MpcElementwiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
  void Make() override {
    AddInput("X", "(Tensor), The first input tensor of mpc elementwise add op.");
    AddInput("Y", "(Tensor), The second input tensor of mpc elementwise add op.");
    AddOutput("Out", "(Tensor), The output tensor of mpc elementwise add op.");
    AddAttr<int>("axis",
                 "(int, default -1). If X.dimension != Y.dimension,"

@@ -92,24 +87,23 @@ public:
      ctx->ShareLoD("Y", /*->*/ y_grad_name);
    }
  }
};
template <typename T>
-class MpcElementwiseAddOpGradMaker : public framework::SingleGradOpDescMaker {
+class MpcElementwiseAddOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> retv(new T());
-    retv->SetType("mpc_elementwise_add_grad");
-    retv->SetInput("X", this->Input("X"));
-    retv->SetInput("Y", this->Input("Y"));
-    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-    retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
-    retv->SetAttrMap(this->Attrs());
-    return retv;
-  }
+  void Apply(GradOpPtr<T> grad) const override {
+    grad->SetType("mpc_elementwise_add_grad");
+    grad->SetInput("X", this->Input("X"));
+    grad->SetInput("Y", this->Input("Y"));
+    grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
+    grad->SetAttrMap(this->Attrs());
+  }
};
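Note: this hunk is the core of the 1.8.0 port. Paddle 1.8 replaces SingleGradOpDescMaker, whose Apply() allocated and returned a std::unique_ptr<T>, with SingleGradOpMaker<T>, whose Apply(GradOpPtr<T>) fills in a grad-op description the framework already owns. A toy contrast of the two calling conventions (the OpDesc below is a stand-in, not Paddle's type):

    #include <iostream>
    #include <memory>
    #include <string>

    // Toy grad-op description; a stand-in for Paddle's framework::OpDesc.
    struct OpDesc {
        std::string type;
        void SetType(const std::string &t) { type = t; }
    };

    // Pre-1.8 convention: the maker allocates and returns the grad op itself.
    std::unique_ptr<OpDesc> ApplyOld() {
        std::unique_ptr<OpDesc> retv(new OpDesc());
        retv->SetType("mpc_elementwise_add_grad");
        return retv;
    }

    // 1.8 convention: the framework allocates; Apply(GradOpPtr<T>) only fills it in.
    void ApplyNew(OpDesc *grad) {
        grad->SetType("mpc_elementwise_add_grad");
    }

    int main() {
        auto old_style = ApplyOld();
        OpDesc new_style;
        ApplyNew(&new_style);
        std::cout << old_style->type << " / " << new_style.type << "\n";
    }

The same mechanical rewrite repeats below for the sub, mean, and mul grad makers.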
@@ -127,6 +121,6 @@ REGISTER_OP_CPU_KERNEL(
    mpc_elementwise_add,
    ops::MpcElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>);

REGISTER_OP_CPU_KERNEL(
    mpc_elementwise_add_grad,
    ops::MpcElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// This op is different from elementwise_add of PaddlePaddle.
// We only consider the case where the dimensions of X are equal to the dimensions of Y.

@@ -18,7 +18,6 @@
#pragma once

#include "mpc_op.h"
#include "paddle/fluid/platform/transform.h"
-#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"

namespace paddle {
namespace operators {

@@ -26,12 +25,12 @@ namespace operators {
using Tensor = framework::Tensor;

// paddle/fluid/operators/elementwise/elementwise_op_function.h
template <typename T, typename DeviceContext>
class RowwiseTransformIterator;

template <typename T>
class RowwiseTransformIterator<T, platform::CPUDeviceContext>
    : public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t, T *, T &> {
public:
  RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {}

@@ -54,13 +53,11 @@ public:
    return *this;
  }

  bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext> &rhs) const {
    return (ptr_ + i_) == &(*rhs);
  }

  bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext> &rhs) const {
    return (ptr_ + i_) != &(*rhs);
  }

@@ -72,15 +69,15 @@ private:
  int64_t n_;
};

template <typename T>
struct AddFunctor {
  inline HOSTDEVICE T operator()(T x, T y) { return x + y; }
};

struct GetMidDims {
  inline HOSTDEVICE void operator()(const framework::DDim &x_dims,
                                    const framework::DDim &y_dims, const int axis,
                                    int *pre, int *n, int *post) {
    *pre = 1;
    *n = 1;
    *post = 1;
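Note: the rest of GetMidDims is collapsed above. Its usual contract (an assumption, based on Paddle's elementwise_op_function.h) is pre = product of X's dims before axis, n = product of Y's dims, post = product of X's remaining dims. A small sketch, including the post == 1 case the kernels below enforce:

    #include <cassert>
    #include <vector>

    // Plain-C++ restatement of the assumed GetMidDims contract.
    void GetMidDimsSketch(const std::vector<int> &x, const std::vector<int> &y,
                          int axis, int *pre, int *n, int *post) {
        *pre = *n = *post = 1;
        for (int i = 0; i < axis; ++i) *pre *= x[i];
        for (int d : y) *n *= d;
        for (int i = axis + static_cast<int>(y.size()); i < static_cast<int>(x.size()); ++i)
            *post *= x[i];
    }

    int main() {
        int pre, n, post;
        // One share slice of X is 3x4, Y's slice is a length-4 row, axis = 1:
        GetMidDimsSketch({3, 4}, {4}, 1, &pre, &n, &post);
        assert(pre == 3 && n == 4 && post == 1);  // kernels below require post == 1
    }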
@@ -105,18 +102,17 @@ const size_t SHARE_NUM = 2;
template <typename DeviceContext, typename T>
class MpcElementwiseAddKernel : public MpcOpKernel<T> {
public:
  void ComputeImpl(const framework::ExecutionContext &ctx) const override {
-    auto *in_x_t = ctx.Input<Tensor>("X");
-    auto *in_y_t = ctx.Input<Tensor>("Y");
-    auto *out_t = ctx.Output<Tensor>("Out");
+    auto *in_x_t = ctx.Input<framework::LoDTensor>("X");
+    auto *in_y_t = ctx.Input<framework::LoDTensor>("Y");
+    auto *out_t = ctx.Output<framework::LoDTensor>("Out");
    int axis = ctx.Attr<int>("axis");
    auto out = out_t->mutable_data<T>(ctx.GetPlace());

    if (in_x_t->dims() == in_y_t->dims()) {
      mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(in_x_t, in_y_t, out_t);
    } else {
      Tensor in_x_t_slice;
      Tensor in_y_t_slice;

@@ -137,8 +133,8 @@ public:
      int pre, n, post;
      GetMidDims get_mid_dims;
      get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);

      PADDLE_ENFORCE_EQ(post, 1,
                        "post should be equal to 1, but received post is [%s]", post);

      auto x_ = in_x_t_slice.data<T>();
      auto y_ = in_y_t_slice.data<T>();

@@ -146,8 +142,8 @@ public:
      auto nx_ = in_x_t_slice.numel();

      paddle::platform::Transform<DeviceContext> trans;
      trans(ctx.template device_context<DeviceContext>(), x_, x_ + nx_,
            RowwiseTransformIterator<T, DeviceContext>(y_, n),
            out_, AddFunctor<T>());
      }
    }
  }
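Note: in the broadcast branch, RowwiseTransformIterator wraps Y's pointer so its index cycles modulo n while Transform walks all of X, adding the same n-element Y row to every row of the X slice. The effect, stripped of the Paddle types (a minimal sketch, not the real iterator):

    #include <cstdio>
    #include <vector>

    int main() {
        const int n = 4;  // width of the broadcast row, as computed by GetMidDims
        // One share slice: x is 3 rows of n columns, y is a single n-element row.
        std::vector<long long> x(12, 10), y{1, 2, 3, 4}, out(12);
        // What trans(..., RowwiseTransformIterator<T, ...>(y_, n), out_, AddFunctor<T>())
        // computes: the y index wraps every n elements while x is walked once.
        for (int k = 0; k < 12; ++k) out[k] = x[k] + y[k % n];
        for (long long v : out) std::printf("%lld ", v);  // 11 12 13 14, three times
        std::printf("\n");
    }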
@@ -159,9 +155,9 @@ public:
  void ComputeImpl(const framework::ExecutionContext &ctx) const override {
    auto *in_x_t = ctx.Input<framework::LoDTensor>("X");
    auto *in_y_t = ctx.Input<framework::LoDTensor>("Y");
-    auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    auto *dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+    auto *dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto *dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
    int axis = ctx.Attr<int>("axis");
    auto dout_data = dout->data<T>();

@@ -189,8 +185,8 @@ public:
    int pre, n, post;
    GetMidDims get_mid_dims;
    get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
    PADDLE_ENFORCE_EQ(post, 1,
                      "post should be equal to 1, but received post is [%s]", post);

    for (size_t i = 0; i < SHARE_NUM; ++i) {
      int y_offset = i * n;

@@ -212,3 +208,4 @@ public:
} // namespace operators
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

-#include "mpc_elementwise_sub_op.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "mpc_elementwise_sub_op.h"

namespace paddle {
namespace operators {

@@ -22,25 +22,21 @@ class MpcElementwiseSubOp : public framework::OperatorWithKernel {
public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::NotFound("Input(X) of MpcElementwiseSubOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Y"), true,
        platform::errors::NotFound("Input(Y) of MpcElementwiseSubOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Out"), true,
        platform::errors::NotFound("Output(Out) of MpcElementwiseSubOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->GetInputDim("X"), ctx->GetInputDim("Y"),
        platform::errors::InvalidArgument(
            "The dimensions of X should be equal to the dimensions of Y. "
            "But received the dimensions of X is [%s], the dimensions of Y is [%s]",
            ctx->GetInputDim("X"), ctx->GetInputDim("Y")));

    ctx->ShareDim("X", /*->*/ "Out");

@@ -51,10 +47,8 @@ public:
class MpcElementwiseSubOpMaker : public framework::OpProtoAndCheckerMaker {
public:
  void Make() override {
    AddInput("X", "(Tensor), The first input tensor of mpc elementwise sub op.");
    AddInput("Y", "(Tensor), The second input tensor of mpc elementwise sub op.");
    AddOutput("Out", "(Tensor), The output tensor of mpc elementwise sub op.");
    AddComment(R"DOC(
MPC elementwise sub Operator.

@@ -86,21 +80,19 @@ public:
};
template <typename T>
-class MpcElementwiseSubGradMaker : public framework::SingleGradOpDescMaker {
+class MpcElementwiseSubGradMaker : public framework::SingleGradOpMaker<T> {
public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> retv(new T());
-    retv->SetType("mpc_elementwise_sub_grad");
-    retv->SetInput("X", this->Input("X"));
-    retv->SetInput("Y", this->Input("Y"));
-    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-    retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
-    retv->SetAttrMap(this->Attrs());
-    return retv;
-  }
+  void Apply(GradOpPtr<T> grad) const override {
+    grad->SetType("mpc_elementwise_sub_grad");
+    grad->SetInput("X", this->Input("X"));
+    grad->SetInput("Y", this->Input("Y"));
+    grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
+    grad->SetAttrMap(this->Attrs());
+  }
};
@@ -118,6 +110,6 @@ REGISTER_OP_CPU_KERNEL(
    mpc_elementwise_sub,
    ops::MpcElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>);

REGISTER_OP_CPU_KERNEL(
    mpc_elementwise_sub_grad,
    ops::MpcElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// This op is different from elementwise_sub of PaddlePaddle.
// We only consider the case where the dimensions of X are equal to the dimensions of Y.

#pragma once

#include "mpc_op.h"
-#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"

namespace paddle {
namespace operators {

@@ -33,8 +32,7 @@ public:
auto *out_t = ctx.Output<Tensor>("Out"); auto *out_t = ctx.Output<Tensor>("Out");
auto out = out_t->mutable_data<T>(ctx.GetPlace()); auto out = out_t->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub( mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(in_x_t, in_y_t, out_t);
in_x_t, in_y_t, out_t);
} }
}; };
...@@ -56,11 +54,11 @@ public: ...@@ -56,11 +54,11 @@ public:
} }
if (dy) { if (dy) {
dy->mutable_data<T>(ctx.GetPlace()); dy->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->neg( mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->neg(dout, dy);
dout, dy);
} }
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
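Note on the gradient above: with Out = X - Y, the chain rule gives dOut/dX = 1 and dOut/dY = -1, so dx is a share-wise copy of dout and dy = neg(dout). That is why this backward kernel needs no protocol call beyond neg.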
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// Description:

#include "paddle/fluid/framework/op_registry.h"
-#include "core/paddlefl_mpc/mpc_protocol/mpc_config.h"
 #include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
+#include "core/paddlefl_mpc/mpc_protocol/mpc_config.h"

namespace paddle {
namespace operators {

@@ -26,10 +26,10 @@ using mpc::Aby3Config;
class MpcInitOp : public framework::OperatorBase {
public:
  MpcInitOp(const std::string& type,
            const framework::VariableNameMap& inputs,
            const framework::VariableNameMap& outputs,
            const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  void RunImpl(const framework::Scope &scope,
               const platform::Place &dev_place) const override {

@@ -55,24 +55,26 @@ public:
AddComment(R"DOC( AddComment(R"DOC(
Where2 Operator. Where2 Operator.
)DOC"); )DOC");
AddAttr<std::string>("protocol_name", "(string , default aby3)" AddAttr<std::string>("protocol_name",
"(string , default aby3)"
"protocol name") "protocol name")
.SetDefault({"aby3"}); .SetDefault({"aby3"});
AddAttr<int>("role", "trainer role.").SetDefault(0); AddAttr<int>("role", "trainer role.").SetDefault(0);
AddAttr<std::string>("local_addr", "(string, default localhost)" AddAttr<std::string>("local_addr",
"(string, default localhost)"
"local addr") "local addr")
.SetDefault({"localhost"}); .SetDefault({"localhost"});
AddAttr<std::string>("net_server_addr", "(string, default localhost)" AddAttr<std::string>("net_server_addr",
"(string, default localhost)"
"net server addr") "net server addr")
.SetDefault({"localhost"}); .SetDefault({"localhost"});
AddAttr<int>("net_server_port", "net server port, default to 6539.") AddAttr<int>("net_server_port", "net server port, default to 6539.").SetDefault(6539);
.SetDefault(6539);
} }
}; };
class MpcInitOpShapeInference : public framework::InferShapeBase { class MpcInitOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *ctx) const override {} void operator()(framework::InferShapeContext* ctx) const override {}
}; };
} // namespace operators

@@ -80,5 +82,7 @@ public:
namespace ops = paddle::operators;

REGISTER_OPERATOR(
    mpc_init, ops::MpcInitOp,
    ops::MpcInitOpMaker, ops::MpcInitOpShapeInference);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

-#include "mpc_mean_op.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "mpc_mean_op.h"

namespace paddle {
namespace operators {

@@ -24,13 +24,13 @@ class MpcMeanOp : public framework::OperatorWithKernel {
public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::NotFound("Input(X) of MpcMeanOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Out"), true,
        platform::errors::NotFound("Output(Out) of MpcMeanOp should not be null."));
    ctx->SetOutputDim("Out", {2, 1});
  }
};
@@ -48,9 +48,10 @@ MPC mean Operator calculates the mean of all elements in X.
class MpcMeanOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
-  std::unordered_map<std::string, std::string>
-  GetInputOutputWithSameType() const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
-  }
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
+      const override {
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
+    return m;
+  }
};
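Note: the rewritten GetInputOutputWithSameType returns a reference, and a reference must outlive the call; hence the function-local static in the new body (returning a reference to a temporary would dangle). The pattern in isolation:

    #include <iostream>
    #include <string>
    #include <unordered_map>

    // Returning a reference to a function-local static: constructed once on
    // first call, thread-safe since C++11, alive for the rest of the program.
    const std::unordered_map<std::string, std::string> &InOutWithSameType() {
        static std::unordered_map<std::string, std::string> m{{"X", "Out"}};
        return m;
    }

    int main() {
        std::cout << InOutWithSameType().at("X") << "\n";  // prints Out
    }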
@@ -63,21 +64,20 @@ public:
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    ctx->ShareLoD("X", framework::GradVarName("X"));
  }
};

template <typename T>
-class MpcMeanOpGradMaker : public framework::SingleGradOpDescMaker {
+class MpcMeanOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> retv(new T());
-    retv->SetType("mpc_mean_grad");
-    retv->SetInput("X", this->Input("X"));
-    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-    return retv;
-  }
+  void Apply(GradOpPtr<T> grad) const override {
+    grad->SetType("mpc_mean_grad");
+    grad->SetInput("X", this->Input("X"));
+    grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+  }
};
@@ -85,14 +85,16 @@ protected:
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(mpc_mean, ops::MpcMeanOp,
                  ops::MpcMeanOpMaker,
                  ops::MpcMeanOpInferVarType,
                  ops::MpcMeanOpGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_mean_grad, ops::MpcMeanGradOp);

REGISTER_OP_CPU_KERNEL(
    mpc_mean,
    ops::MpcMeanKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
    mpc_mean_grad,
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "mpc_op.h"
-#include "paddle/fluid/framework/eigen.h"

namespace paddle {
namespace operators {

@@ -33,10 +32,8 @@ public:
auto *out_t = ctx.Output<Tensor>("Out"); auto *out_t = ctx.Output<Tensor>("Out");
out_t->mutable_data<T>(ctx.GetPlace()); out_t->mutable_data<T>(ctx.GetPlace());
double scale = 1.0 / (in_x_t->numel() / 2.0); double scale = 1.0 / (in_x_t->numel() / 2.0);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sum( mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sum(in_x_t, out_t);
in_x_t, out_t); mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(out_t, scale, out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(
out_t, scale, out_t);
} }
}; };
...@@ -45,8 +42,7 @@ class MpcMeanGradKernel : public MpcOpKernel<T> { ...@@ -45,8 +42,7 @@ class MpcMeanGradKernel : public MpcOpKernel<T> {
public: public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override { void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto dout = ctx.Input<Tensor>(framework::GradVarName("Out")); auto dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE(dout->numel() == 2, PADDLE_ENFORCE(dout->numel() == 2, "numel of MpcMean Gradient should be 2.");
"numel of MpcMean Gradient should be 2.");
auto dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dout_data = dout->data<T>(); auto dout_data = dout->data<T>();
...@@ -60,11 +56,11 @@ public: ...@@ -60,11 +56,11 @@ public:
} }
double scale_factor = 1.0 / (dx->numel() / 2); double scale_factor = 1.0 / (dx->numel() / 2);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale( mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(dx, scale_factor, dx);
dx, scale_factor, dx);
} }
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
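Note on the scale arithmetic above: dim 0 of every MPC tensor holds the SHARE_NUM = 2 secret shares, so numel() reports twice the plaintext element count, and scale = 1.0 / (numel / 2.0) divides by the true count. A quick check:

    #include <cassert>

    int main() {
        // dims {2, 4, 5}: two shares of a 4x5 plaintext tensor.
        const long long numel = 2 * 4 * 5;         // what in_x_t->numel() reports
        const double scale = 1.0 / (numel / 2.0);  // 1 / 20 plaintext elements
        assert(scale == 0.05);
    }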
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

-#include "mpc_mul_op.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "mpc_mul_op.h"

namespace paddle {
namespace operators {

@@ -24,16 +24,16 @@ class MpcMulOp : public framework::OperatorWithKernel {
public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::NotFound("Input(X) of MpcMulOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Y"), true,
        platform::errors::NotFound("Input(Y) of MpcMulOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Out"), true,
        platform::errors::NotFound("Output(Out) of MpcMulOp should not be null."));

    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");

@@ -86,8 +86,8 @@ public:
        x_dims, x_mat_width, y_dims, y_mat_height));

    std::vector<int64_t> output_dims;
    output_dims.reserve(
        static_cast<size_t>(1 + x_num_col_dims + y_dims.size() - y_num_col_dims));

    for (int i = 0; i <= x_num_col_dims; ++i) {  // i=0, batch_size (share id)
      output_dims.push_back(x_dims[i]);

@@ -153,7 +153,8 @@ public:
        "same purpose as scale_weights in OPs that support quantization."
        "Only to be used with MKL-DNN INT8")
        .SetDefault({1.0f});
    AddAttr<float>("scale_out",
                   "scale_out to be used for int8 output data."
                   "Only used with MKL-DNN INT8")
        .SetDefault(1.0f);
    AddAttr<bool>(

@@ -169,9 +170,10 @@ MPC mul Operator.
class MpcMulOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
-  std::unordered_map<std::string, std::string>
-  GetInputOutputWithSameType() const override {
-    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
-  }
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
+      const override {
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
+    return m;
+  }
};
@@ -202,36 +204,37 @@ public:
};

template <typename T>
-class MpcMulOpGradMaker : public framework::SingleGradOpDescMaker {
+class MpcMulOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
-  std::unique_ptr<T> Apply() const override {
-    std::unique_ptr<T> retv(new T());
-    retv->SetType("mpc_mul_grad");
-    retv->SetInput("X", this->Input("X"));
-    retv->SetInput("Y", this->Input("Y"));
-    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-    retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
-    retv->SetAttrMap(this->Attrs());
-    return retv;
-  }
+  void Apply(GradOpPtr<T> grad) const override {
+    grad->SetType("mpc_mul_grad");
+    grad->SetInput("X", this->Input("X"));
+    grad->SetInput("Y", this->Input("Y"));
+    grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
+    grad->SetAttrMap(this->Attrs());
+  }
};
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(mpc_mul, ops::MpcMulOp,
                  ops::MpcMulOpMaker,
                  ops::MpcMulOpInferVarType,
                  ops::MpcMulOpGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_mul_grad, ops::MpcMulGradOp);

REGISTER_OP_CPU_KERNEL(
    mpc_mul,
    ops::MpcMulKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
    mpc_mul_grad,
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "mpc_op.h"
-#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"

namespace paddle {
namespace operators {

@@ -48,7 +47,7 @@ public:
    }
    for (size_t i = 1; i < y_dims.size(); i++) {
      if (i <= y_num_col_dims) {
-        x_mat_width *= y_dims[i];
+        y_mat_width *= y_dims[i];
      } else {
        y_mat_height *= y_dims[i];
      }
    }

@@ -59,13 +58,8 @@ public:
    x_matrix.ShareDataWith(*x);
    y_matrix.ShareDataWith(*y);
-    if (x_dims.size() > 3) {
-      x_matrix.Resize({2, x_mat_width, x_mat_height});
-    }
-    if (y_dims.size() > 3) {
-      y_matrix.Resize({2, y_mat_width, y_mat_height});
-    }
+    x_matrix.Resize({2, x_mat_width, x_mat_height});
+    y_matrix.Resize({2, y_mat_width, y_mat_height});

    out->mutable_data<T>(ctx.GetPlace());

@@ -80,15 +74,17 @@ public:
    if (out_dim.size() > 3) {
      out->Resize(out_dim);
    }
  }
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MpcMulGradKernel : public MpcOpKernel<T> { class MpcMulGradKernel : public MpcOpKernel<T> {
public: public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override { void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<framework::LoDTensor>("X"); auto* x = ctx.Input<framework::LoDTensor>("X");
auto *y = ctx.Input<framework::LoDTensor>("Y"); auto* y = ctx.Input<framework::LoDTensor>("Y");
auto *dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out")); auto *dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X")); auto *dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto *dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y")); auto *dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
...@@ -125,17 +121,9 @@ public: ...@@ -125,17 +121,9 @@ public:
y_matrix.ShareDataWith(*y); y_matrix.ShareDataWith(*y);
dout_matrix.ShareDataWith(*dout); dout_matrix.ShareDataWith(*dout);
if (x_dims.size() > 3) {
x_matrix.Resize({2, x_mat_width, x_mat_height}); x_matrix.Resize({2, x_mat_width, x_mat_height});
}
if (y_dims.size() > 3) {
y_matrix.Resize({2, y_mat_width, y_mat_height}); y_matrix.Resize({2, y_mat_width, y_mat_height});
}
if (dout_dims.size() > 3) {
dout_matrix.Resize({2, x_mat_width, y_mat_height}); dout_matrix.Resize({2, x_mat_width, y_mat_height});
}
if (dx != nullptr) { if (dx != nullptr) {
dx->set_lod(x->lod()); dx->set_lod(x->lod());
...@@ -149,15 +137,10 @@ public: ...@@ -149,15 +137,10 @@ public:
x_matrix_trans.mutable_data<T>(x->dims(), ctx.GetPlace()); x_matrix_trans.mutable_data<T>(x->dims(), ctx.GetPlace());
y_matrix_trans.mutable_data<T>(y->dims(), ctx.GetPlace()); y_matrix_trans.mutable_data<T>(y->dims(), ctx.GetPlace());
if (x_dims.size() >= 3) {
x_matrix_trans.Resize({2, x_mat_height, x_mat_width}); x_matrix_trans.Resize({2, x_mat_height, x_mat_width});
}
if (y_dims.size() >= 3) {
y_matrix_trans.Resize({2, y_mat_height, y_mat_width}); y_matrix_trans.Resize({2, y_mat_height, y_mat_width});
}
auto &dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
const int Rank = 3; const int Rank = 3;
Eigen::array<int, Rank> permute; Eigen::array<int, Rank> permute;
...@@ -172,7 +155,7 @@ public: ...@@ -172,7 +155,7 @@ public:
} }
auto eigen_in = framework::EigenTensor<T, Rank>::From(y_matrix); auto eigen_in = framework::EigenTensor<T, Rank>::From(y_matrix);
auto eigen_out = framework::EigenTensor<T, Rank>::From(y_matrix_trans); auto eigen_out = framework::EigenTensor<T, Rank>::From(y_matrix_trans);
auto *dev = dev_ctx.eigen_device(); auto* dev = dev_ctx.eigen_device();
eigen_out.device(*dev) = eigen_in.shuffle(permute); eigen_out.device(*dev) = eigen_in.shuffle(permute);
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N // dx = dout * y'. dx: M x K, dout : M x N, y : K x N
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul( mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(
...@@ -191,7 +174,7 @@ public: ...@@ -191,7 +174,7 @@ public:
auto eigen_in = framework::EigenTensor<T, Rank>::From(x_matrix); auto eigen_in = framework::EigenTensor<T, Rank>::From(x_matrix);
auto eigen_out = framework::EigenTensor<T, Rank>::From(x_matrix_trans); auto eigen_out = framework::EigenTensor<T, Rank>::From(x_matrix_trans);
auto *dev = dev_ctx.eigen_device(); auto* dev = dev_ctx.eigen_device();
eigen_out.device(*dev) = eigen_in.shuffle(permute); eigen_out.device(*dev) = eigen_in.shuffle(permute);
// dy = x' * dout. dy K x N, dout : M x N, x : M x K // dy = x' * dout. dy K x N, dout : M x N, x : M x K
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul( mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(
...@@ -206,3 +189,4 @@ public: ...@@ -206,3 +189,4 @@ public:
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
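A note on the shapes in the two kernels above: every tensor here carries a leading share dimension of size 2 (two secret shares per element under the ABY3-style sharing from core/privc3), so matrices are laid out as {2, rows, cols} and only the two trailing axes participate in the transpose. As the inline comments state, dx = dout * y' with dx: M x K, dout: M x N, y: K x N; for a concrete check, take M = 4, K = 3, N = 2: dout is {2, 4, 2}, y' is {2, 2, 3}, and their product is {2, 4, 3}, i.e. M x K, as required. The Rank-3 Eigen shuffle whose permutation values fall outside the visible hunk is therefore presumably {0, 2, 1}: keep the share axis, swap rows and columns.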
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
// Description: // Description:
#pragma once #pragma once
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h" #include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/privc3/circuit_context.h" #include "core/privc3/circuit_context.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T> class MpcOpKernel : public framework::OpKernelBase { template <typename T>
class MpcOpKernel : public framework::OpKernelBase {
public: public:
using ELEMENT_TYPE = T; using ELEMENT_TYPE = T;
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_instance()->mpc_protocol(), PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_instance()->mpc_protocol(),
"Mpc protocol is not yet initialized in executor"); "Mpc protocol is not yet initialized in executor");
std::shared_ptr<aby3::CircuitContext> mpc_ctx( std::shared_ptr<aby3::CircuitContext> mpc_ctx(mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context());
mpc::ContextHolder::template run_with_context<>(&ctx, mpc_ctx, mpc::ContextHolder::template run_with_context<>(&ctx, mpc_ctx,
[&] { ComputeImpl(ctx); }); [&] { ComputeImpl(ctx); });
} }
virtual void ComputeImpl(const framework::ExecutionContext &ctx) const = 0; virtual void ComputeImpl(const framework::ExecutionContext& ctx) const = 0;
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
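The base class above is the hook every kernel in this PR plugs into: Compute() first enforces that the MPC protocol is initialized, then captures the circuit context and runs the derived ComputeImpl() inside it via ContextHolder::run_with_context. A hypothetical minimal kernel, shown only to make the contract concrete — the class name is illustrative and not part of the PR, and the add() call follows the operator signatures that appear elsewhere in this diff:

template <typename DeviceContext, typename T>
class MyMpcAddKernel : public MpcOpKernel<T> {  // illustrative name, not in the PR
public:
    void ComputeImpl(const framework::ExecutionContext &ctx) const override {
        auto *x = ctx.Input<framework::Tensor>("X");
        auto *y = ctx.Input<framework::Tensor>("Y");
        auto *out = ctx.Output<framework::Tensor>("Out");
        out->mutable_data<T>(ctx.GetPlace());
        // Runs inside the circuit context installed by MpcOpKernel::Compute().
        mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(x, y, out);
    }
};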
...@@ -18,20 +18,20 @@ ...@@ -18,20 +18,20 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// forward op definition //forward op definition
class MpcReluOp : public framework::OperatorWithKernel { class MpcReluOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim("X"); auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Y", in_dims); ctx->SetOutputDim("Y", in_dims);
} }
}; };
// forward input & output definition //forward input & output definition
class MpcReluOpMaker : public framework::OpProtoAndCheckerMaker { class MpcReluOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "The input tensor."); AddInput("X", "The input tensor.");
AddOutput("Y", "Output of relu_op"); AddOutput("Y", "Output of relu_op");
...@@ -41,31 +41,30 @@ Mpc Relu Operator. ...@@ -41,31 +41,30 @@ Mpc Relu Operator.
} }
}; };
// backward op definition //backward op definition
class MpcReluGradOp : public framework::OperatorWithKernel { class MpcReluGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim(framework::GradVarName("Y")); auto in_dims = ctx->GetInputDim(framework::GradVarName("Y"));
ctx->SetOutputDim(framework::GradVarName("X"), in_dims); ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
} }
}; };
// backward type, input & output definition //backward type, input & output definition
template <typename T> template <typename T>
class MpcReluGradMaker : public framework::SingleGradOpDescMaker { class MpcReluGradMaker : public framework::SingleGradOpMaker<T> {
public: public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override { protected:
auto *op = new T(); void Apply(GradOpPtr<T> grad) const override {
op->SetType("mpc_relu_grad"); grad->SetType("mpc_relu_grad");
op->SetInput("Y", this->Output("Y")); grad->SetInput("Y", this->Output("Y"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); grad->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetAttrMap(this->Attrs()); grad->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
return std::unique_ptr<T>(op);
} }
}; };
...@@ -76,8 +75,12 @@ namespace ops = paddle::operators; ...@@ -76,8 +75,12 @@ namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext; using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(mpc_relu, ops::MpcReluOp, ops::MpcReluOpMaker, REGISTER_OPERATOR(mpc_relu,
ops::MpcReluOp,
ops::MpcReluOpMaker,
ops::MpcReluGradMaker<paddle::framework::OpDesc>); ops::MpcReluGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_relu_grad, ops::MpcReluGradOp); REGISTER_OPERATOR(mpc_relu_grad, ops::MpcReluGradOp);
REGISTER_OP_CPU_KERNEL(mpc_relu, ops::MpcReluKernel<CPU, int64_t>); REGISTER_OP_CPU_KERNEL(mpc_relu,
REGISTER_OP_CPU_KERNEL(mpc_relu_grad, ops::MpcReluGradKernel<CPU, int64_t>); ops::MpcReluKernel<CPU, int64_t>);
REGISTER_OP_CPU_KERNEL(mpc_relu_grad,
ops::MpcReluGradKernel<CPU, int64_t>);
...@@ -14,43 +14,37 @@ ...@@ -14,43 +14,37 @@
#pragma once #pragma once
#include "mpc_op.h" #include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
// Define forward computation //Define forward computation
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MpcReluKernel : public MpcOpKernel<T> { class MpcReluKernel : public MpcOpKernel<T> {
public: public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override { void ComputeImpl(const framework::ExecutionContext& ctx) const override {
const Tensor *in_t = ctx.Input<Tensor>("X"); const Tensor* in_t = ctx.Input<Tensor>("X");
Tensor *out_t = ctx.Output<Tensor>("Y"); Tensor* out_t = ctx.Output<Tensor>("Y");
auto x = in_t->data<T>(); auto x = in_t->data<T>();
auto y = out_t->mutable_data<T>(ctx.GetPlace()); auto y = out_t->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, "Protocol %s is not yet created in MPC Protocol.");
"Protocol %s is not yet created in MPC Protocol."); mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->relu(in_t,out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->relu(
in_t, out_t);
} }
}; };
// Define backward computation //Define backward computation
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MpcReluGradKernel : public MpcOpKernel<T> { class MpcReluGradKernel : public MpcOpKernel<T> {
public: public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override { void ComputeImpl(const framework::ExecutionContext& ctx) const override {
auto *dy_t = ctx.Input<Tensor>(framework::GradVarName("Y")); auto* dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto *y_t = ctx.Input<Tensor>("Y"); auto* y_t = ctx.Input<Tensor>("Y");
auto *dx_t = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dx = dx_t->mutable_data<T>(ctx.GetPlace()); auto dx = dx_t->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance() mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->relu_grad(y_t, dy_t, dx_t, 0.0);
->mpc_protocol()
->mpc_operators()
->relu_grad(y_t, dy_t, dx_t, 0.0);
} }
}; };
} // namespace operators }// namespace operators
} // namespace paddle }// namespace paddle
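For intuition about the backward kernel above: it evaluates the usual ReLU derivative from the saved forward output, dx = dy * 1[y > 0], entirely under MPC. The trailing 0.0 passed to relu_grad is presumably a negative-slope parameter (0 for plain ReLU); the mpc_operators interface itself is not part of this diff, so that reading is an inference.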
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
#include "mpc_sgd_op.h" #include "mpc_sgd_op.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -55,8 +55,8 @@ public: ...@@ -55,8 +55,8 @@ public:
} }
protected: protected:
framework::OpKernelType framework::OpKernelType GetExpectedKernelType(
GetExpectedKernelType(const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param"); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
...@@ -65,19 +65,19 @@ protected: ...@@ -65,19 +65,19 @@ protected:
class MpcSGDOpInferVarType : public framework::VarTypeInference { class MpcSGDOpInferVarType : public framework::VarTypeInference {
public: public:
void operator()(framework::InferVarTypeContext *ctx) const override { void operator()(framework::InferVarTypeContext *ctx) const override {
auto &input_var_n = ctx->Input("Param")[0]; auto in_var_type = ctx->GetInputType("Param");
auto in_var_type = ctx->GetType(input_var_n);
PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS || PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
in_var_type == framework::proto::VarType::LOD_TENSOR, in_var_type == framework::proto::VarType::LOD_TENSOR,
"The input Var's type should be LoDtensor or SelectedRows," "The input Var's type should be LoDtensor or SelectedRows,"
" but the received var(%s)'s type is %s", " but the received var(%s)'s type is %s",
input_var_n, in_var_type); ctx->InputVarName("Param"), in_var_type);
ctx->SetOutputType("ParamOut", in_var_type);
for (auto &out_var_n : ctx->Output("ParamOut")) { //for (auto &out_var_n : framework::StaticGraphVarTypeInference::Output(ctx, "ParamOut")) {
if (ctx->GetType(out_var_n) != in_var_type) { // if (ctx->GetVarType(out_var_n) != in_var_type) {
ctx->SetType(out_var_n, in_var_type); // ctx->SetType(out_var_n, in_var_type);
} //}
} //}
} }
}; };
...@@ -108,7 +108,7 @@ $$param\_out = param - learning\_rate * grad$$ ...@@ -108,7 +108,7 @@ $$param\_out = param - learning\_rate * grad$$
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR( REGISTER_OPERATOR(
mpc_sgd, ops::MpcSGDOp, ops::MpcSGDOpMaker, mpc_sgd, ops::MpcSGDOp, ops::MpcSGDOpMaker,
// paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
ops::MpcSGDOpInferVarType); ops::MpcSGDOpInferVarType);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
mpc_sgd, ops::MpcSGDOpKernel<paddle::platform::CPUDeviceContext, int64_t>); mpc_sgd,
ops::MpcSGDOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
#pragma once #pragma once
#include "mpc_op.h" #include "mpc_op.h"
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MpcSGDOpKernel : public MpcOpKernel<T> { class MpcSGDOpKernel : public MpcOpKernel<T> {
public: public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override { void ComputeImpl(const framework::ExecutionContext &ctx) const override{
const auto *param_var = ctx.InputVar("Param"); const auto *param_var = ctx.InputVar("Param");
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true, PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
"The Var(%s)'s type should be LoDTensor, " "The Var(%s)'s type should be LoDTensor, "
"but the received is %s", "but the received is %s",
ctx.Inputs("Param").front(), ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type())); framework::ToTypeName(param_var->Type()));
const auto *grad_var = ctx.InputVar("Grad"); const auto *grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true, PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
"The Var(%s)'s type should be LoDTensor, " "The Var(%s)'s type should be LoDTensor, "
"but the received is %s", "but the received is %s",
ctx.Inputs("Grad").front(), ctx.InputNames("Grad").front(),
framework::ToTypeName(grad_var->Type())); framework::ToTypeName(grad_var->Type()));
const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate"); const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
...@@ -49,19 +48,14 @@ public: ...@@ -49,19 +48,14 @@ public:
PADDLE_ENFORCE_EQ(grad->numel(), sz); PADDLE_ENFORCE_EQ(grad->numel(), sz);
const double *lr = learning_rate->data<double>(); const double *lr = learning_rate->data<double>();
// const T *param_data = param->data<T>();
// const T *grad_data = grad->data<T>();
T *out_data = param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, "Protocol %s is not yet created in MPC Protocol.");
"Protocol %s is not yet created in MPC Protocol.");
// update parameters // update parameters
framework::Tensor temp; framework::Tensor temp;
temp.mutable_data<T>(param->dims(), ctx.GetPlace()); temp.mutable_data<T>(param->dims(), ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale( mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(grad, lr[0], &temp);
grad, lr[0], &temp); mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(param, &temp, param_out);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(
param, &temp, param_out);
} }
}; };
} // namespace operators } // namespace operators
......
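The kernel above realizes the documented update $$param\_out = param - learning\_rate * grad$$ as exactly two protocol calls: scale(grad, lr[0], &temp) followed by sub(param, &temp, param_out). A plaintext analogue for a quick sanity check — it mirrors only the two-step structure and the equal-size precondition (PADDLE_ENFORCE_EQ(grad->numel(), sz)), not the MPC code path, which operates on int64_t fixed-point shares:

#include <cstddef>
#include <vector>

// Plaintext analogue of the MPC SGD step: temp = lr * grad, then
// param_out = param - temp. Types and values are illustrative only.
std::vector<double> sgd_step(const std::vector<double> &param,
                             const std::vector<double> &grad, double lr) {
    std::vector<double> temp(grad.size());
    std::vector<double> out(param.size());
    for (std::size_t i = 0; i < grad.size(); ++i) temp[i] = lr * grad[i];       // scale
    for (std::size_t i = 0; i < param.size(); ++i) out[i] = param[i] - temp[i]; // sub
    return out;
}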
...@@ -117,21 +117,19 @@ MpcSigmoidCrossEntropyWithLogits Operator. ...@@ -117,21 +117,19 @@ MpcSigmoidCrossEntropyWithLogits Operator.
}; };
template <typename T> template <typename T>
class MpcSigmoidCrossEntropyWithLogitsGradOpMaker : public framework::SingleGradOpDescMaker { class MpcSigmoidCrossEntropyWithLogitsGradOpMaker : public framework::SingleGradOpMaker<T> {
public: public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> grad) const override {
std::unique_ptr<T> retv(new T()); grad->SetType("mpc_sigmoid_cross_entropy_with_logits_grad");
retv->SetType("mpc_sigmoid_cross_entropy_with_logits_grad"); grad->SetInput("X", this->Input("X"));
retv->SetInput("X", this->Input("X")); grad->SetInput("Label", this->Input("Label"));
retv->SetInput("Label", this->Input("Label")); grad->SetInput("Out", this->Output("Out"));
retv->SetInput("Out", this->Output("Out")); grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); grad->SetAttrMap(this->Attrs());
retv->SetAttrMap(this->Attrs());
return retv;
} }
}; };
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
#include "mpc_square_op.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "mpc_square_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
class MpcSquareOp : public framework::OperatorWithKernel { class MpcSquareOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, PADDLE_ENFORCE_EQ(
platform::errors::NotFound( ctx->HasInput("X"), true,
"Input(X) of MpcSquareOp should not be null.")); platform::errors::NotFound("Input(X) of MpcSquareOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, PADDLE_ENFORCE_EQ(
platform::errors::NotFound( ctx->HasOutput("Out"), true,
"Output(Out) of MpcSquareOp should not be null.")); platform::errors::NotFound("Output(Out) of MpcSquareOp should not be null."));
ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
...@@ -59,26 +60,26 @@ public: ...@@ -59,26 +60,26 @@ public:
}; };
template <typename T> template <typename T>
class MpcSquareGradOpMaker : public framework::SingleGradOpDescMaker { class MpcSquareGradOpMaker : public framework::SingleGradOpMaker<T> {
public: public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> grad) const override {
std::unique_ptr<T> retv(new T()); grad->SetType("mpc_square_grad");
retv->SetType("mpc_square_grad"); grad->SetInput("X", this->Input("X"));
retv->SetInput("X", this->Input("X")); grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
return retv;
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_square, ops::MpcSquareOp, ops::MpcSquareOpMaker, REGISTER_OPERATOR(mpc_square, ops::MpcSquareOp,
ops::MpcSquareOpMaker,
ops::MpcSquareGradOpMaker<paddle::framework::OpDesc>); ops::MpcSquareGradOpMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_square_grad, ops::MpcSquareGradOp); REGISTER_OPERATOR(mpc_square_grad, ops::MpcSquareGradOp);
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
#pragma once #pragma once
#include "mpc_op.h" #include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -27,8 +27,7 @@ public: ...@@ -27,8 +27,7 @@ public:
auto *in_x_t = ctx.Input<Tensor>("X"); auto *in_x_t = ctx.Input<Tensor>("X");
auto *out_t = ctx.Output<Tensor>("Out"); auto *out_t = ctx.Output<Tensor>("Out");
out_t->mutable_data<T>(ctx.GetPlace()); out_t->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul( mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul(in_x_t, in_x_t, out_t);
in_x_t, in_x_t, out_t);
} }
}; };
...@@ -43,13 +42,12 @@ public: ...@@ -43,13 +42,12 @@ public:
// allocate memory on device. // allocate memory on device.
dx_t->mutable_data<T>(ctx.GetPlace()); dx_t->mutable_data<T>(ctx.GetPlace());
// dx = dout * 2 * x // dx = dout * 2 * x
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale( mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(in_x_t, 2.0, dx_t);
in_x_t, 2.0, dx_t); mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul(dx_t, dout_t, dx_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul(
dx_t, dout_t, dx_t);
} }
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
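A quick check of the backward arithmetic above: the forward pass computes out = x * x elementwise, so dx = dout * 2x. The kernel first writes 2x into dx via scale(in_x_t, 2.0, dx_t), then multiplies in place with mul(dx_t, dout_t, dx_t). For a single element with x = 3 and dout = 0.5, that yields dx = 0.5 * 2 * 3 = 3.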
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "mpc_sum_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type_inference.h" #include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_sum_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -31,11 +30,10 @@ class MpcSumOp : public framework::OperatorWithKernel { ...@@ -31,11 +30,10 @@ class MpcSumOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx->HasInputs("X"), true, ctx->HasInputs("X"), true,
platform::errors::NotFound( platform::errors::NotFound("Input(X) of MpcElementwiseAddOp should not be null."));
"Input(X) of MpcElementwiseAddOp should not be null."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true, ctx->HasOutput("Out"), true,
...@@ -45,7 +43,8 @@ public: ...@@ -45,7 +43,8 @@ public:
auto x_dims = ctx->GetInputsDim("X"); auto x_dims = ctx->GetInputsDim("X");
auto N = x_dims.size(); auto N = x_dims.size();
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
N, 0, "ShapeError: The input tensor X's dimensions of SumOp " N, 0,
"ShapeError: The input tensor X's dimensions of SumOp "
"should be larger than 0. But received X's dimensions %d, " "should be larger than 0. But received X's dimensions %d, "
"X's shape = [%s].", "X's shape = [%s].",
N, &x_dims); N, &x_dims);
...@@ -55,7 +54,7 @@ public: ...@@ -55,7 +54,7 @@ public:
framework::DDim in_dim({0}); framework::DDim in_dim({0});
for (size_t i = 0; i < x_dims.size(); ++i) { for (size_t i = 0; i < x_dims.size(); ++i) {
auto &x_dim = x_dims[i]; auto& x_dim = x_dims[i];
// x_dim.size() == 1 means the real dim of selected rows is [0] // x_dim.size() == 1 means the real dim of selected rows is [0]
if (x_var_types[i] == framework::proto::VarType::SELECTED_ROWS && if (x_var_types[i] == framework::proto::VarType::SELECTED_ROWS &&
x_dim.size() == 1) { x_dim.size() == 1) {
...@@ -99,6 +98,7 @@ public: ...@@ -99,6 +98,7 @@ public:
ctx->SetOutputDim("Out", in_dim); ctx->SetOutputDim("Out", in_dim);
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
}; };
class MpcSumOpMaker : public framework::OpProtoAndCheckerMaker { class MpcSumOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -110,7 +110,8 @@ public: ...@@ -110,7 +110,8 @@ public:
"or LoDTensor, and data types can be: float32, float64, int32, " "or LoDTensor, and data types can be: float32, float64, int32, "
"int64.") "int64.")
.AsDuplicable(); .AsDuplicable();
AddOutput("Out", "the sum of input :code:`x`. its shape and data types are " AddOutput("Out",
"the sum of input :code:`x`. its shape and data types are "
"consistent with :code:`x`."); "consistent with :code:`x`.");
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
...@@ -121,6 +122,7 @@ public: ...@@ -121,6 +122,7 @@ public:
} }
}; };
class MpcSumGradMaker : public framework::GradOpDescMakerBase { class MpcSumGradMaker : public framework::GradOpDescMakerBase {
public: public:
using framework::GradOpDescMakerBase::GradOpDescMakerBase; using framework::GradOpDescMakerBase::GradOpDescMakerBase;
...@@ -131,8 +133,8 @@ public: ...@@ -131,8 +133,8 @@ public:
grad_ops.reserve(x_grads.size()); grad_ops.reserve(x_grads.size());
auto og = OutputGrad("Out"); auto og = OutputGrad("Out");
std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops), std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
[&og](const std::string &x_grad) { [&og](const std::string& x_grad) {
auto *grad_op = new framework::OpDesc(); auto* grad_op = new framework::OpDesc();
grad_op->SetType("scale"); grad_op->SetType("scale");
grad_op->SetInput("X", og); grad_op->SetInput("X", og);
grad_op->SetOutput("Out", {x_grad}); grad_op->SetOutput("Out", {x_grad});
...@@ -151,9 +153,10 @@ DECLARE_INPLACE_OP_INFERER(MpcSumInplace, {"X", "Out"}); ...@@ -151,9 +153,10 @@ DECLARE_INPLACE_OP_INFERER(MpcSumInplace, {"X", "Out"});
namespace ops = paddle::operators; namespace ops = paddle::operators;
// REGISTER_OP_WITHOUT_GRADIENT(mpc_sum, ops::MpcSumOp, ops::MpcSumOpMaker); //REGISTER_OP_WITHOUT_GRADIENT(mpc_sum, ops::MpcSumOp, ops::MpcSumOpMaker);
REGISTER_OPERATOR(mpc_sum, ops::MpcSumOp, ops::MpcSumOpMaker, REGISTER_OPERATOR(mpc_sum, ops::MpcSumOp,
ops::MpcSumGradMaker, ops::MpcSumInplace); ops::MpcSumOpMaker,
ops::MpcSumGradMaker,
ops::MpcSumInplace);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(mpc_sum, ops::MpcSumKernel<paddle::platform::CPUDeviceContext, int64_t>);
mpc_sum, ops::MpcSumKernel<paddle::platform::CPUDeviceContext, int64_t>);
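Worth noting: MpcSumGradMaker derives from GradOpDescMakerBase rather than SingleGradOpMaker because sum takes a variable number of inputs. Since d(sum)/d(x_i) = dout for every input x_i, each input grad is produced by an ordinary scale op reading the shared output grad; the scale attribute falls outside the visible hunk but is presumably 1.0.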
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
// You may obtain a copy of the License at You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and See the License for the specific language governing permissions and
// limitations under the License. limitations under the License. */
#pragma once #pragma once
#include "mpc_op.h" #include "mpc_op.h"
...@@ -45,10 +45,7 @@ public: ...@@ -45,10 +45,7 @@ public:
auto &in_0 = in_vars[0]->Get<framework::LoDTensor>(); auto &in_0 = in_vars[0]->Get<framework::LoDTensor>();
auto &in_1 = in_vars[1]->Get<framework::LoDTensor>(); auto &in_1 = in_vars[1]->Get<framework::LoDTensor>();
if (in_0.numel() && in_1.numel()) { if (in_0.numel() && in_1.numel()) {
mpc::MpcInstance::mpc_instance() mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(&in_0, &in_1, out);
->mpc_protocol()
->mpc_operators()
->add(&in_0, &in_1, out);
start = 2; start = 2;
} }
} }
...@@ -66,15 +63,12 @@ public: ...@@ -66,15 +63,12 @@ public:
if (in_t.numel() == 0) { if (in_t.numel() == 0) {
continue; continue;
} }
mpc::MpcInstance::mpc_instance() mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(out, &in_t, out);
->mpc_protocol()
->mpc_operators()
->add(out, &in_t, out);
} else { } else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows."); PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
} }
} }
} else { }else {
PADDLE_THROW("Unexpected branch, output variable type is %s", PADDLE_THROW("Unexpected branch, output variable type is %s",
framework::ToTypeName(out_var->Type())); framework::ToTypeName(out_var->Type()));
} }
...@@ -82,3 +76,4 @@ public: ...@@ -82,3 +76,4 @@ public:
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
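Condensing the kernel's control flow from the hunks above (variable names follow the diff; this restates, rather than changes, the behavior): the first two non-empty inputs seed the output, then the rest accumulate in place, with add() taking out as both operand and destination:

auto ops = mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators();
ops->add(&in_0, &in_1, out);    // out = in[0] + in[1], when both are non-empty; start = 2
for (size_t i = start; i < in_vars.size(); ++i) {
    auto &in_t = in_vars[i]->Get<framework::LoDTensor>();
    if (in_t.numel() == 0) {
        continue;
    }
    ops->add(out, &in_t, out);  // out += in[i]; relies on add() accepting an aliased output
}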
(This diff is collapsed.)
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/framework/tensor.h"
namespace aby3 { namespace aby3 {
...@@ -28,10 +28,13 @@ using paddle::framework::Tensor; ...@@ -28,10 +28,13 @@ using paddle::framework::Tensor;
class PaddleTensorTest : public ::testing::Test { class PaddleTensorTest : public ::testing::Test {
public: public:
std::shared_ptr<TensorAdapterFactory> _tensor_factory; std::shared_ptr<TensorAdapterFactory> _tensor_factory;
CPUDeviceContext _cpu_ctx; CPUDeviceContext _cpu_ctx;
virtual ~PaddleTensorTest() noexcept {}
void SetUp() { void SetUp() {
_tensor_factory = std::make_shared<PaddleTensorFactory>(&_cpu_ctx); _tensor_factory = std::make_shared<PaddleTensorFactory>(&_cpu_ctx);
} }
...@@ -39,21 +42,20 @@ public: ...@@ -39,21 +42,20 @@ public:
TEST_F(PaddleTensorTest, factory_test) { TEST_F(PaddleTensorTest, factory_test) {
EXPECT_NO_THROW(_tensor_factory->template create<int64_t>()); EXPECT_NO_THROW(_tensor_factory->template create<int64_t>());
std::vector<size_t> shape = {2, 3}; std::vector<size_t> shape = { 2, 3 };
EXPECT_NO_THROW(_tensor_factory->template create<int64_t>(shape)); EXPECT_NO_THROW(_tensor_factory->template create<int64_t>(shape));
} }
TEST_F(PaddleTensorTest, ctor_test) { TEST_F(PaddleTensorTest, ctor_test) {
Tensor t; Tensor t;
// t holds no memory // t holds no memory
EXPECT_THROW({ PaddleTensor<int64_t> pt(&_cpu_ctx, t); }, EXPECT_THROW({ PaddleTensor<int64_t> pt(&_cpu_ctx, t); }, ::paddle::platform::EnforceNotMet);
::paddle::platform::EnforceNotMet);
t.template mutable_data<int64_t>(_cpu_ctx.GetPlace()); t.template mutable_data<int64_t>(_cpu_ctx.GetPlace());
EXPECT_NO_THROW({ PaddleTensor<int64_t> pt(&_cpu_ctx, t); }); EXPECT_NO_THROW({ PaddleTensor<int64_t> pt(&_cpu_ctx, t); });
} }
TEST_F(PaddleTensorTest, shape_test) { TEST_F(PaddleTensorTest, shape_test) {
std::vector<size_t> shape = {2, 3}; std::vector<size_t> shape = { 2, 3 };
auto pt = _tensor_factory->template create<int64_t>(shape); auto pt = _tensor_factory->template create<int64_t>(shape);
EXPECT_EQ(shape.size(), pt->shape().size()); EXPECT_EQ(shape.size(), pt->shape().size());
...@@ -65,7 +67,7 @@ TEST_F(PaddleTensorTest, shape_test) { ...@@ -65,7 +67,7 @@ TEST_F(PaddleTensorTest, shape_test) {
} }
TEST_F(PaddleTensorTest, reshape_test) { TEST_F(PaddleTensorTest, reshape_test) {
std::vector<size_t> shape = {2, 3}; std::vector<size_t> shape = { 2, 3 };
auto pt = _tensor_factory->template create<int64_t>(); auto pt = _tensor_factory->template create<int64_t>();
pt->reshape(shape); pt->reshape(shape);
...@@ -77,7 +79,7 @@ TEST_F(PaddleTensorTest, reshape_test) { ...@@ -77,7 +79,7 @@ TEST_F(PaddleTensorTest, reshape_test) {
} }
TEST_F(PaddleTensorTest, add_test) { TEST_F(PaddleTensorTest, add_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape); auto pt2 = _tensor_factory->template create<int64_t>(shape);
...@@ -89,7 +91,7 @@ TEST_F(PaddleTensorTest, add_test) { ...@@ -89,7 +91,7 @@ TEST_F(PaddleTensorTest, add_test) {
} }
TEST_F(PaddleTensorTest, sub_test) { TEST_F(PaddleTensorTest, sub_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape); auto pt2 = _tensor_factory->template create<int64_t>(shape);
...@@ -101,7 +103,7 @@ TEST_F(PaddleTensorTest, sub_test) { ...@@ -101,7 +103,7 @@ TEST_F(PaddleTensorTest, sub_test) {
} }
TEST_F(PaddleTensorTest, negative_test) { TEST_F(PaddleTensorTest, negative_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 2; pt0->data()[0] = 2;
...@@ -111,7 +113,7 @@ TEST_F(PaddleTensorTest, negative_test) { ...@@ -111,7 +113,7 @@ TEST_F(PaddleTensorTest, negative_test) {
} }
TEST_F(PaddleTensorTest, mul_test) { TEST_F(PaddleTensorTest, mul_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape); auto pt2 = _tensor_factory->template create<int64_t>(shape);
...@@ -123,7 +125,7 @@ TEST_F(PaddleTensorTest, mul_test) { ...@@ -123,7 +125,7 @@ TEST_F(PaddleTensorTest, mul_test) {
} }
TEST_F(PaddleTensorTest, div_test) { TEST_F(PaddleTensorTest, div_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape); auto pt2 = _tensor_factory->template create<int64_t>(shape);
...@@ -135,9 +137,9 @@ TEST_F(PaddleTensorTest, div_test) { ...@@ -135,9 +137,9 @@ TEST_F(PaddleTensorTest, div_test) {
} }
TEST_F(PaddleTensorTest, matmul_test) { TEST_F(PaddleTensorTest, matmul_test) {
std::vector<size_t> shape0 = {2, 3}; std::vector<size_t> shape0 = { 2, 3 };
std::vector<size_t> shape1 = {3, 2}; std::vector<size_t> shape1 = { 3, 2 };
std::vector<size_t> shape2 = {2, 2}; std::vector<size_t> shape2 = { 2, 2 };
auto pt0 = _tensor_factory->template create<int64_t>(shape0); auto pt0 = _tensor_factory->template create<int64_t>(shape0);
auto pt1 = _tensor_factory->template create<int64_t>(shape1); auto pt1 = _tensor_factory->template create<int64_t>(shape1);
auto pt2 = _tensor_factory->template create<int64_t>(shape2); auto pt2 = _tensor_factory->template create<int64_t>(shape2);
...@@ -151,7 +153,7 @@ TEST_F(PaddleTensorTest, matmul_test) { ...@@ -151,7 +153,7 @@ TEST_F(PaddleTensorTest, matmul_test) {
// | 3 4 5 | x | 2 3 | = | 28 40 | // | 3 4 5 | x | 2 3 | = | 28 40 |
// | 4 5 | // | 4 5 |
std::vector<int64_t> res = {10, 13, 28, 40}; std::vector<int64_t> res = { 10, 13, 28, 40 };
bool eq = std::equal(res.begin(), res.end(), pt2->data()); bool eq = std::equal(res.begin(), res.end(), pt2->data());
...@@ -159,7 +161,7 @@ TEST_F(PaddleTensorTest, matmul_test) { ...@@ -159,7 +161,7 @@ TEST_F(PaddleTensorTest, matmul_test) {
} }
TEST_F(PaddleTensorTest, xor_test) { TEST_F(PaddleTensorTest, xor_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape); auto pt2 = _tensor_factory->template create<int64_t>(shape);
...@@ -171,7 +173,7 @@ TEST_F(PaddleTensorTest, xor_test) { ...@@ -171,7 +173,7 @@ TEST_F(PaddleTensorTest, xor_test) {
} }
TEST_F(PaddleTensorTest, and_test) { TEST_F(PaddleTensorTest, and_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape); auto pt2 = _tensor_factory->template create<int64_t>(shape);
...@@ -183,7 +185,7 @@ TEST_F(PaddleTensorTest, and_test) { ...@@ -183,7 +185,7 @@ TEST_F(PaddleTensorTest, and_test) {
} }
TEST_F(PaddleTensorTest, or_test) { TEST_F(PaddleTensorTest, or_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape); auto pt2 = _tensor_factory->template create<int64_t>(shape);
...@@ -195,7 +197,7 @@ TEST_F(PaddleTensorTest, or_test) { ...@@ -195,7 +197,7 @@ TEST_F(PaddleTensorTest, or_test) {
} }
TEST_F(PaddleTensorTest, not_test) { TEST_F(PaddleTensorTest, not_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 0; pt0->data()[0] = 0;
...@@ -205,7 +207,7 @@ TEST_F(PaddleTensorTest, not_test) { ...@@ -205,7 +207,7 @@ TEST_F(PaddleTensorTest, not_test) {
} }
TEST_F(PaddleTensorTest, lshift_test) { TEST_F(PaddleTensorTest, lshift_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 2; pt0->data()[0] = 2;
...@@ -215,7 +217,7 @@ TEST_F(PaddleTensorTest, lshift_test) { ...@@ -215,7 +217,7 @@ TEST_F(PaddleTensorTest, lshift_test) {
} }
TEST_F(PaddleTensorTest, rshift_test) { TEST_F(PaddleTensorTest, rshift_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 2; pt0->data()[0] = 2;
...@@ -225,7 +227,7 @@ TEST_F(PaddleTensorTest, rshift_test) { ...@@ -225,7 +227,7 @@ TEST_F(PaddleTensorTest, rshift_test) {
} }
TEST_F(PaddleTensorTest, logical_rshift_test) { TEST_F(PaddleTensorTest, logical_rshift_test) {
std::vector<size_t> shape = {1}; std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape); auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape); auto pt1 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = -1; pt0->data()[0] = -1;
...@@ -234,16 +236,17 @@ TEST_F(PaddleTensorTest, logical_rshift_test) { ...@@ -234,16 +236,17 @@ TEST_F(PaddleTensorTest, logical_rshift_test) {
EXPECT_EQ(-1ull >> 1, pt1->data()[0]); EXPECT_EQ(-1ull >> 1, pt1->data()[0]);
} }
TEST_F(PaddleTensorTest, scale_test) { TEST_F(PaddleTensorTest, scale_test) {
auto pt = _tensor_factory->template create<int64_t>(); auto pt = _tensor_factory->template create<int64_t>();
auto pt_ = dynamic_cast<PaddleTensor<int64_t> *>(pt.get()); auto pt_ = dynamic_cast<PaddleTensor<int64_t>*>(pt.get());
pt_->scaling_factor() = 1; pt_->scaling_factor() = 1;
Tensor t; Tensor t;
int dim[1] = {1}; int dim[1] = { 1 };
paddle::framework::DDim ddim(dim, 1); paddle::framework::DDim ddim(dim, 1);
t.template mutable_data<float>(ddim, _cpu_ctx.GetPlace()); t.template mutable_data<float>(ddim, _cpu_ctx.GetPlace());
...@@ -258,11 +261,11 @@ TEST_F(PaddleTensorTest, scale_test) { ...@@ -258,11 +261,11 @@ TEST_F(PaddleTensorTest, scale_test) {
TEST_F(PaddleTensorTest, scalar_test) { TEST_F(PaddleTensorTest, scalar_test) {
auto pt = _tensor_factory->template create<int64_t>(); auto pt = _tensor_factory->template create<int64_t>();
auto pt_ = dynamic_cast<PaddleTensor<int64_t> *>(pt.get()); auto pt_ = dynamic_cast<PaddleTensor<int64_t>*>(pt.get());
pt_->scaling_factor() = 1; pt_->scaling_factor() = 1;
std::vector<size_t> shape = {2}; std::vector<size_t> shape = { 2 };
pt_->template from_float_point_scalar(0.25f, shape, 2); pt_->template from_float_point_scalar(0.25f, shape, 2);
EXPECT_EQ(2, pt_->scaling_factor()); EXPECT_EQ(2, pt_->scaling_factor());
...@@ -271,11 +274,11 @@ TEST_F(PaddleTensorTest, scalar_test) { ...@@ -271,11 +274,11 @@ TEST_F(PaddleTensorTest, scalar_test) {
} }
TEST_F(PaddleTensorTest, slice_test) { TEST_F(PaddleTensorTest, slice_test) {
std::vector<size_t> shape = {2, 2}; std::vector<size_t> shape = { 2, 2 };
auto pt = _tensor_factory->template create<int64_t>(shape); auto pt = _tensor_factory->template create<int64_t>(shape);
auto ret = _tensor_factory->template create<int64_t>(); auto ret = _tensor_factory->template create<int64_t>();
auto pt_ = dynamic_cast<PaddleTensor<int64_t> *>(pt.get()); auto pt_ = dynamic_cast<PaddleTensor<int64_t>*>(pt.get());
pt_->scaling_factor() = 1; pt_->scaling_factor() = 1;
for (size_t i = 0; i < 4; ++i) { for (size_t i = 0; i < 4; ++i) {
......
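The scale and scalar tests above pin down the fixed-point encoding used for shares: raw values are int64_t with an attached scaling_factor f, presumably stored as round(v * 2^f). Under that reading, from_float_point_scalar(0.25f, shape, 2) sets the factor to 2 (which the EXPECT_EQ confirms) and fills every slot with 0.25 * 2^2 = 1.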
...@@ -21,14 +21,13 @@ from paddle.fluid import core ...@@ -21,14 +21,13 @@ from paddle.fluid import core
from paddle.fluid import unique_name from paddle.fluid import unique_name
from paddle.fluid.framework import Variable from paddle.fluid.framework import Variable
from paddle.fluid.framework import convert_np_dtype_to_dtype_ from paddle.fluid.framework import convert_np_dtype_to_dtype_
from paddle.fluid.data_feeder import check_type, check_dtype
class MpcVariable(Variable): class MpcVariable(Variable):
""" """
Extends from paddle.fluid.framework.Variable and rewrite Extends from paddle.fluid.framework.Variable and rewrite
the __init__ method where the shape is resized. the __init__ method where the shape is resized.
""" """
def __init__(self, def __init__(self,
block, block,
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
...@@ -91,22 +90,22 @@ class MpcVariable(Variable): ...@@ -91,22 +90,22 @@ class MpcVariable(Variable):
else: else:
old_dtype = self.dtype old_dtype = self.dtype
if dtype != old_dtype: if dtype != old_dtype:
raise ValueError( raise ValueError("MpcVariable {0} has been created before. "
"MpcVariable {0} has been created before. "
"The previous data type is {1}; the new " "The previous data type is {1}; the new "
"data type is {2}. They are not " "data type is {2}. They are not "
"matched.".format(self.name, old_dtype, dtype)) "matched.".format(self.name, old_dtype,
dtype))
if lod_level is not None: if lod_level is not None:
if is_new_var: if is_new_var:
self.desc.set_lod_level(lod_level) self.desc.set_lod_level(lod_level)
else: else:
if lod_level != self.lod_level: if lod_level != self.lod_level:
raise ValueError( raise ValueError("MpcVariable {0} has been created before. "
"MpcVariable {0} has been created before. "
"The previous lod_level is {1}; the new " "The previous lod_level is {1}; the new "
"lod_level is {2}. They are not " "lod_level is {2}. They are not "
"matched".format(self.name, self.lod_level, lod_level)) "matched".format(self.name, self.lod_level,
lod_level))
if persistable is not None: if persistable is not None:
if is_new_var: if is_new_var:
self.desc.set_persistable(persistable) self.desc.set_persistable(persistable)
...@@ -156,8 +155,7 @@ class MpcParameter(MpcVariable): ...@@ -156,8 +155,7 @@ class MpcParameter(MpcVariable):
if len(shape) == 0: if len(shape) == 0:
raise ValueError( raise ValueError(
"The dimensions of shape for MpcParameter must be greater than 0" "The dimensions of shape for MpcParameter must be greater than 0")
)
for each in shape: for each in shape:
if each < 0: if each < 0:
...@@ -175,8 +173,7 @@ class MpcParameter(MpcVariable): ...@@ -175,8 +173,7 @@ class MpcParameter(MpcVariable):
**kwargs) **kwargs)
self.trainable = kwargs.get('trainable', True) self.trainable = kwargs.get('trainable', True)
self.optimize_attr = kwargs.get('optimize_attr', self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
{'learning_rate': 1.0})
self.regularizer = kwargs.get('regularizer', None) self.regularizer = kwargs.get('regularizer', None)
...@@ -203,8 +200,8 @@ class MpcParameter(MpcVariable): ...@@ -203,8 +200,8 @@ class MpcParameter(MpcVariable):
additional_attr = ("trainable", "optimize_attr", "regularizer", additional_attr = ("trainable", "optimize_attr", "regularizer",
"gradient_clip_attr", "do_model_average") "gradient_clip_attr", "do_model_average")
for attr_name in additional_attr: for attr_name in additional_attr:
res_str += "%s: %s\n" % ( res_str += "%s: %s\n" % (attr_name,
attr_name, cpt.to_text(getattr(self, attr_name))) cpt.to_text(getattr(self, attr_name)))
else: else:
res_str = MpcVariable.to_string(self, throw_on_error, False) res_str = MpcVariable.to_string(self, throw_on_error, False)
return res_str return res_str
@@ -245,8 +242,7 @@ def create_mpc_parameter(block, *args, **kwargs):
     init_ops_len = len(init_ops)
     if init_ops_len > 1:
         raise RuntimeError("mpc_param " + mpc_param.name +
-                           " is inited by multiple init ops " + str(
-                               init_ops))
+                           " is inited by multiple init ops " + str(init_ops))
     elif init_ops_len == 1:
         # TODO(Paddle 1.7): already inited, do nothing, should log a warning
         pass
@@ -272,7 +268,6 @@ def create_mpc_var(block, *args, **kwargs):
         kwargs['initializer'](var, block)
     return var

 def is_mpc_parameter(var):
     """
     Check whether the given variable is an instance of MpcParameter.
@@ -282,4 +277,13 @@ def is_mpc_parameter(var):
         bool: True if the given `var` is an instance of Parameter,
               False if not.
     """
-    return isinstance(var, MpcParameter)
+    return type(var) == MpcParameter
+
+
+def check_mpc_variable_and_dtype(input,
+                                 input_name,
+                                 expected_dtype,
+                                 op_name,
+                                 extra_message=''):
+    check_type(input, input_name, MpcVariable, op_name, extra_message)
+    check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message)
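For orientation: the new check_mpc_variable_and_dtype helper is the MPC
counterpart of fluid's check_variable_and_dtype. It pins the type check to
MpcVariable and then defers to check_dtype. A hypothetical layer would use it
as in the sketch below ("mpc_example" is a made-up op name; the helper calls
follow the patterns visible elsewhere in this diff):

    def mpc_example_layer(x):
        # Sketch only: validate the input, allocate an output variable,
        # then append the operator to the current program.
        helper = MpcLayerHelper("mpc_example", **locals())
        # MPC tensors carry secret shares as int64, hence the dtype list.
        check_mpc_variable_and_dtype(x, 'x', ['int64'], 'mpc_example')
        out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype)
        helper.append_op(type="mpc_example", inputs={"X": x}, outputs={"Out": out})
        return out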
@@ -14,9 +14,10 @@
 """
 basic mpc op layers.
 """
-from paddle.fluid.data_feeder import check_type_and_dtype
+from paddle.fluid.data_feeder import check_variable_and_dtype
 from ..framework import MpcVariable
+from ..framework import check_mpc_variable_and_dtype
 from ..mpc_layer_helper import MpcLayerHelper

 __all__ = [
@@ -32,8 +33,8 @@ def _elementwise_op(helper):
     assert x is not None, 'x cannot be None in {}'.format(op_type)
     assert y is not None, 'y cannot be None in {}'.format(op_type)
-    check_type_and_dtype(x, 'x', MpcVariable, ['int64'], op_type)
-    check_type_and_dtype(y, 'y', MpcVariable, ['int64'], op_type)
+    check_mpc_variable_and_dtype(x, 'x', ['int64'], op_type)
+    check_mpc_variable_and_dtype(y, 'y', ['int64'], op_type)
     axis = helper.kwargs.get('axis', -1)
     use_mkldnn = helper.kwargs.get('use_mkldnn', False)
...
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,7 +14,6 @@
 """
 mpc math compare layers.
 """
-from paddle.fluid.data_feeder import check_type_and_dtype
 from ..framework import MpcVariable
 from ..mpc_layer_helper import MpcLayerHelper
...
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,9 +14,9 @@
 """
 mpc math op layers.
 """
-from paddle.fluid.data_feeder import check_type_and_dtype
 from ..framework import MpcVariable
+from ..framework import check_mpc_variable_and_dtype
 from ..mpc_layer_helper import MpcLayerHelper

 __all__ = [
@@ -39,7 +39,7 @@ def mean(x, name=None):
     Examples: todo
     """
     helper = MpcLayerHelper("mean", **locals())
-    check_type_and_dtype(x, 'x', MpcVariable, ['int64'], 'mean')
+    check_mpc_variable_and_dtype(x, 'x', ['int64'], 'mean')
     if name is None:
         out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype)
     else:
@@ -64,7 +64,7 @@ def square(x, name=None):
     Examples: todo
     """
     helper = MpcLayerHelper("square", **locals())
-    check_type_and_dtype(x, 'x', MpcVariable, ['int64'], 'square')
+    check_mpc_variable_and_dtype(x, 'x', ['int64'], 'square')
     if name is None:
         out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype)
     else:
@@ -89,8 +89,7 @@ def sum(x):
     Examples: todo
     """
     helper = MpcLayerHelper("sum", **locals())
-    out = helper.create_mpc_variable_for_type_inference(
-        dtype=helper.input_dtype('x'))
+    out = helper.create_mpc_variable_for_type_inference(dtype=helper.input_dtype('x'))
     helper.append_op(
         type="mpc_sum",
         inputs={"X": x},
@@ -116,16 +115,14 @@ def square_error_cost(input, label):
     Examples: todo
     """
     helper = MpcLayerHelper('square_error_cost', **locals())
-    minus_out = helper.create_mpc_variable_for_type_inference(
-        dtype=input.dtype)
+    minus_out = helper.create_mpc_variable_for_type_inference(dtype=input.dtype)
     helper.append_op(
         type='mpc_elementwise_sub',
         inputs={'X': [input],
                 'Y': [label]},
         outputs={'Out': [minus_out]})
-    square_out = helper.create_mpc_variable_for_type_inference(
-        dtype=input.dtype)
+    square_out = helper.create_mpc_variable_for_type_inference(dtype=input.dtype)
     helper.append_op(
         type='mpc_square',
         inputs={'X': [minus_out]},
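In plaintext terms, square_error_cost chains mpc_elementwise_sub and
mpc_square to produce the elementwise squared difference. A rough numpy
analogue of the underlying math, on unencrypted values and for illustration
only:

    import numpy as np

    def square_error_cost_plaintext(input, label):
        # What the two chained MPC ops compute, minus the secret sharing:
        # an elementwise subtraction followed by an elementwise square.
        minus_out = input - label
        return np.square(minus_out)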
...
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,12 +14,14 @@
 """
 mpc matrix op layers.
 """
-from paddle.fluid.data_feeder import check_type_and_dtype
 from ..framework import MpcVariable
+from ..framework import check_mpc_variable_and_dtype
 from ..mpc_layer_helper import MpcLayerHelper

-__all__ = ['mul', ]
+__all__ = [
+    'mul',
+]

 def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
@@ -66,8 +68,8 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
     }
     helper = MpcLayerHelper("mul", **locals())
-    check_type_and_dtype(x, 'x', MpcVariable, ['int64'], 'mul')
-    check_type_and_dtype(y, 'y', MpcVariable, ['int64'], 'mul')
+    check_mpc_variable_and_dtype(x, 'x', ['int64'], 'mul')
+    check_mpc_variable_and_dtype(y, 'y', ['int64'], 'mul')
     if name is None:
         out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype)
     else:
...
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,9 +17,9 @@ mpc ml op layers.
 from functools import reduce
 from paddle.fluid.data_feeder import check_type, check_dtype
-from paddle.fluid.data_feeder import check_type_and_dtype
 import numpy
 from ..framework import MpcVariable
+from ..framework import check_mpc_variable_and_dtype
 from ..mpc_layer_helper import MpcLayerHelper

 __all__ = [
@@ -30,9 +30,6 @@ __all__ = [
 ]

-# add softmax, relu
 def fc(input,
        size,
        num_flatten_dims=1,
@@ -186,8 +183,7 @@ def softmax(input, use_cudnn=False, name=None, axis=-1):
     """
     attrs = {"axis": axis, "use_cudnn": use_cudnn}
     helper = MpcLayerHelper('softmax', **locals())
-    check_type_and_dtype(input, 'input', MpcVariable,
-                         ['float16', 'float32', 'float64'], 'softmax')
+    check_mpc_variable_and_dtype(input, 'input', ['int64'], 'softmax')
     dtype = helper.input_dtype()
     mpc_softmax_out = helper.create_mpc_variable_for_type_inference(dtype)
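Note that the softmax input check moves from float dtypes to int64. This is
consistent with the rest of the diff: Paddle-MPC tensors hold fixed-point
secret shares stored as int64, so float inputs would never reach this layer.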
@@ -226,7 +222,9 @@ def relu(input, name=None):
     dtype = helper.input_dtype(input_param_name='input')
     out = helper.create_mpc_variable_for_type_inference(dtype)
     helper.append_op(
-        type="mpc_relu", inputs={"X": input}, outputs={"Y": out})
+        type="mpc_relu",
+        inputs={"X": input},
+        outputs={"Y": out})
     return out
...
@@ -32,7 +32,6 @@ def monkey_patch_mpc_variable():
     Monkey patch for operator overloading.
     :return:
     """

     def unique_tmp_name():
         """
         Generate temp name for variable.
@@ -80,7 +79,9 @@ def monkey_patch_mpc_variable():
         tmp_name = unique_tmp_name()
         return block.create_var(name=tmp_name, dtype=dtype)

-    def _elemwise_method_creator_(method_name, op_type, reverse=False):
+    def _elemwise_method_creator_(method_name,
+                                  op_type,
+                                  reverse=False):
         """
         Operator overloading for different method.
         :param method_name: the name of operator which is overloaded.
@@ -88,19 +89,16 @@ def monkey_patch_mpc_variable():
         :param reverse:
         :return:
         """
         def __impl__(self, other_var):
             lhs_dtype = safe_get_dtype(self)
             if method_name in compare_ops:
                 if not isinstance(other_var, Variable):
-                    raise NotImplementedError(
-                        "Unsupported data type of {} for compare operations."
-                        .format(other_var.name))
+                    raise NotImplementedError("Unsupported data type of {} for compare operations."
+                                              .format(other_var.name))
             else:
                 if not isinstance(other_var, MpcVariable):
-                    raise NotImplementedError(
-                        "Unsupported data type of {}.".format(other_var.name))
+                    raise NotImplementedError("Unsupported data type of {}.".format(other_var.name))

             rhs_dtype = safe_get_dtype(other_var)
             if reverse:
@@ -111,8 +109,7 @@ def monkey_patch_mpc_variable():
             if method_name in compare_ops:
                 out = create_new_tmp_var(current_block(self), dtype=rhs_dtype)
             else:
-                out = create_new_tmp_mpc_var(
-                    current_block(self), dtype=lhs_dtype)
+                out = create_new_tmp_mpc_var(current_block(self), dtype=lhs_dtype)
             # out = create_new_tmp_mpc_var(current_block(self), dtype=lhs_dtype)
@@ -179,10 +176,11 @@ def monkey_patch_mpc_variable():
                              ("__lt__", "mpc_less_than", False),
                              ("__le__", "mpc_less_equal", False),
                              ("__gt__", "mpc_greater_than", False),
-                             ("__ge__", "mpc_greater_equal", False)):
+                             ("__ge__", "mpc_greater_equal", False)
+                             ):
         # Not support computation between MpcVariable and scalar.
-        setattr(MpcVariable, method_name,
+        setattr(MpcVariable,
+                method_name,
                 _elemwise_method_creator_(method_name, op_type, reverse)
                 if method_name in supported_mpc_ops else announce_not_impl)
-    # MpcVariable.astype = astype
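Once monkey_patch_mpc_variable() has run, ordinary Python operators on
MpcVariable objects build the corresponding mpc_* ops. Roughly, as a
hypothetical snippet (assuming x and y are int64 MpcVariables in the current
program):

    z = x + y      # appends mpc_elementwise_add; z is a fresh temp MpcVariable
    flag = x >= y  # appends mpc_greater_equal; the result is a plain Variable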
@@ -34,7 +34,7 @@ def python_version():
 max_version, mid_version, min_version = python_version()

 REQUIRED_PACKAGES = [
-    'six >= 1.10.0', 'protobuf >= 3.1.0', 'paddlepaddle == 1.6.3',
+    'six >= 1.10.0', 'protobuf >= 3.1.0', 'paddlepaddle >= 1.8.0', 'paddlepaddle-gpu >= 1.8'
 ]

 if max_version < 3:
...
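The requirement moves from a pinned paddlepaddle == 1.6.3 to >= 1.8.0, with a
GPU wheel listed alongside. A minimal environment check (a sketch; it assumes
paddle.__version__ is populated, as it is in the 1.8 wheels):

    import paddle

    # Fail fast if the installed Paddle predates the 1.8 requirement.
    major, minor = (int(v) for v in paddle.__version__.split('.')[:2])
    assert (major, minor) >= (1, 8), paddle.__version__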