Commit 072ccc32 authored by J jhjiangcs

Improve code to support PaddlePaddle 1.8.0.

Parent 639e920b
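Note on the migration: the changes below track three Paddle 1.8.0 API moves — framework::SingleGradOpDescMaker is replaced by the templated framework::SingleGradOpMaker<T>, whose Apply() fills a framework-provided GradOpPtr<T> instead of returning a std::unique_ptr; GetInputOutputWithSameType() now returns a reference to a static map; and var-type inference goes through InferVarTypeContext accessors such as GetInputType(). A minimal, self-contained model of the grad-maker change (OpDesc and the maker structs here are stand-ins, not the Paddle classes):

#include <memory>
#include <string>

// 1.6.3-style makers returned a freshly allocated op description from Apply();
// 1.8.0-style makers receive a framework-allocated description and only fill it.
struct OpDesc { std::string type; };

struct OldStyleMaker {
  std::unique_ptr<OpDesc> Apply() const {   // Paddle 1.6.3 shape
    auto retv = std::make_unique<OpDesc>();
    retv->type = "mpc_mean_grad";
    return retv;
  }
};

struct NewStyleMaker {
  void Apply(OpDesc* grad) const {          // Paddle 1.8.0 shape
    grad->type = "mpc_mean_grad";
  }
};

int main() {
  OpDesc filled;
  NewStyleMaker{}.Apply(&filled);
  return filled.type == OldStyleMaker{}.Apply()->type ? 0 : 1;  // 0 on success
}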
......@@ -34,8 +34,8 @@ execute_process(COMMAND ${PYTHON} -c "import paddle;print(paddle.version.full_ve
RESULT_VARIABLE ret OUTPUT_VARIABLE paddle_version OUTPUT_STRIP_TRAILING_WHITESPACE)
if (NOT ret)
if (NOT ${paddle_version} STREQUAL "1.6.3")
message(FATAL_ERROR "Paddle installation of 1.6.3 is required but ${paddle_version} is found")
if (NOT ${paddle_version} STREQUAL "1.8.0")
message(FATAL_ERROR "Paddle installation of 1.8.0 is required but ${paddle_version} is found")
endif()
else()
message(FATAL_ERROR "Could not get paddle version.")
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mpc_compare_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_compare_op.h"
namespace paddle {
namespace operators {
......@@ -25,16 +25,16 @@ class MpcCompareOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of MpcCompareOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true,
platform::errors::NotFound(
"Input(Y) of MpcCompareOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of MpcCompareOp should not be null."));
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of MpcCompareOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound("Input(Y) of MpcCompareOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of MpcCompareOp should not be null."));
auto dim_x = ctx->GetInputDim("X");
auto dim_y = ctx->GetInputDim("Y");
......@@ -45,11 +45,12 @@ public:
ctx->ShareLoD("Y", /*->*/ "Out");
}
framework::OpKernelType
GetExpectedKernelType(const framework::ExecutionContext &ctx) const override {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class MpcCompareOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -68,40 +69,27 @@ MPC Compare Operator.
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(mpc_greater_than, ops::MpcCompareOp,
ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(
mpc_greater_than,
ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::MpcGreaterThanFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_greater_equal, ops::MpcCompareOp,
ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(
mpc_greater_equal,
ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::MpcGreaterEqualFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_less_than, ops::MpcCompareOp,
ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(
mpc_less_than, ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::MpcLessThanFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_less_equal, ops::MpcCompareOp,
ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(
mpc_less_equal, ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::MpcLessEqualFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_equal, ops::MpcCompareOp,
ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(
mpc_equal, ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::MpcEqualFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_not_equal, ops::MpcCompareOp,
ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(
mpc_not_equal, ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext,
int64_t, ops::MpcNotEqualFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_greater_than, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_greater_than,
ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcGreaterThanFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_greater_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_greater_equal,
ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcGreaterEqualFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_less_than, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_less_than,
ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcLessThanFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_less_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_less_equal,
ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcLessEqualFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_equal,
ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcEqualFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(mpc_not_equal, ops::MpcCompareOp, ops::MpcCompareOpMaker);
REGISTER_OP_CPU_KERNEL(mpc_not_equal,
ops::MpcCompareOpKernel<paddle::platform::CPUDeviceContext, int64_t, ops::MpcNotEqualFunctor>);
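All six registrations above instantiate the same MpcCompareOpKernel template, so each comparison differs only by its functor (defined in mpc_compare_op.h below). A self-contained toy of that dispatch pattern, with FakeTensor standing in for framework::Tensor and the MPC call replaced by a plain comparison:

#include <cstdint>
#include <iostream>

// Toy version of the functor-dispatch pattern (these types are illustrative,
// not the PaddleFL ones): one kernel template, one functor per comparison,
// so adding an op means adding a functor plus a registration line.
struct FakeTensor { std::int64_t v; };

struct GreaterThanFunctor {
  void Run(const FakeTensor* x, const FakeTensor* y, FakeTensor* out) {
    out->v = (x->v > y->v);  // the real functor calls mpc_operators()->gt(...)
  }
};

template <typename Functor>
struct CompareKernel {
  void Compute(const FakeTensor* x, const FakeTensor* y, FakeTensor* out) const {
    Functor().Run(x, y, out);  // mirrors MpcCompareOpKernel::ComputeImpl
  }
};

int main() {
  FakeTensor x{3}, y{2}, out{0};
  CompareKernel<GreaterThanFunctor>().Compute(&x, &y, &out);
  std::cout << out.v << "\n";  // prints 1
}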
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
#include <math.h>
#include <type_traits>
namespace paddle {
namespace operators {
......@@ -25,50 +22,44 @@ using Tensor = framework::Tensor;
struct MpcGreaterThanFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->gt(
in_x_t, in_y_t, out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->gt(in_x_t, in_y_t, out_t);
}
};
struct MpcGreaterEqualFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->geq(
in_x_t, in_y_t, out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->geq(in_x_t, in_y_t, out_t);
}
};
struct MpcLessThanFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->lt(
in_x_t, in_y_t, out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->lt(in_x_t, in_y_t, out_t);
}
};
struct MpcLessEqualFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->leq(
in_x_t, in_y_t, out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->leq(in_x_t, in_y_t, out_t);
}
};
struct MpcEqualFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->eq(
in_x_t, in_y_t, out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->eq(in_x_t, in_y_t, out_t);
}
};
struct MpcNotEqualFunctor {
void Run(const Tensor *in_x_t, const Tensor *in_y_t, Tensor *out_t) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->neq(
in_x_t, in_y_t, out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->neq(in_x_t, in_y_t, out_t);
}
};
template <typename DeviceContext, typename T, typename Functor>
class MpcCompareOpKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
void ComputeImpl(const framework::ExecutionContext &ctx) const override{
auto *in_x_t = ctx.Input<framework::LoDTensor>("X");
auto *in_y_t = ctx.Input<framework::LoDTensor>("Y");
auto *out_t = ctx.Output<framework::LoDTensor>("Out");
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mpc_elementwise_add_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_elementwise_add_op.h"
namespace paddle {
namespace operators {
......@@ -24,38 +24,33 @@ class MpcElementwiseAddOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of MpcElementwiseAddOp should not be null."));
platform::errors::NotFound("Input(X) of MpcElementwiseAddOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound(
"Input(Y) of MpcElementwiseAddOp should not be null."));
platform::errors::NotFound("Input(Y) of MpcElementwiseAddOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of MpcElementwiseAddOp should not be null."));
platform::errors::NotFound("Output(Out) of MpcElementwiseAddOp should not be null."));
PADDLE_ENFORCE_GE(
ctx->GetInputDim("X").size(), ctx->GetInputDim("Y").size(),
platform::errors::InvalidArgument(
"The dimensions of X should be greater than the dimensions of Y. "
"But received the dimensions of X is [%s], the dimensions of Y is "
"[%s]",
"But received the dimensions of X is [%s], the dimensions of Y is [%s]",
ctx->GetInputDim("X"), ctx->GetInputDim("Y")));
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class MpcElementwiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), The first input tensor of mpc elementwise add op.");
AddInput("Y",
"(Tensor), The second input tensor of mpc elementwise add op.");
AddInput("X", "(Tensor), The first input tensor of mpc elementwise add op.");
AddInput("Y", "(Tensor), The second input tensor of mpc elementwise add op.");
AddOutput("Out", "(Tensor), The output tensor of mpc elementwise add op.");
AddAttr<int>("axis",
"(int, default -1). If X.dimension != Y.dimension,"
......@@ -92,24 +87,23 @@ public:
ctx->ShareLoD("Y", /*->*/ y_grad_name);
}
}
};
template <typename T>
class MpcElementwiseAddOpGradMaker : public framework::SingleGradOpDescMaker {
class MpcElementwiseAddOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> retv(new T());
retv->SetType("mpc_elementwise_add_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput("Y", this->Input("Y"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
retv->SetAttrMap(this->Attrs());
return retv;
void Apply(GradOpPtr<T> grad) const override {
grad->SetType("mpc_elementwise_add_grad");
grad->SetInput("X", this->Input("X"));
grad->SetInput("Y", this->Input("Y"));
grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
grad->SetAttrMap(this->Attrs());
}
};
......@@ -127,6 +121,6 @@ REGISTER_OP_CPU_KERNEL(
mpc_elementwise_add,
ops::MpcElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(mpc_elementwise_add_grad,
ops::MpcElementwiseAddGradKernel<
paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
mpc_elementwise_add_grad,
ops::MpcElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// This op is different with elementwise_add of PaddlePaddle.
// We only consider that the dimensions of X is equal with the dimensions of Y.
......@@ -18,7 +18,6 @@
#pragma once
#include "mpc_op.h"
#include "paddle/fluid/platform/transform.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle {
namespace operators {
......@@ -26,12 +25,12 @@ namespace operators {
using Tensor = framework::Tensor;
// paddle/fluid/operators/elementwise/elementwise_op_function.h
template <typename T, typename DeviceContext> class RowwiseTransformIterator;
template <typename T, typename DeviceContext>
class RowwiseTransformIterator;
template <typename T>
class RowwiseTransformIterator<T, platform::CPUDeviceContext>
: public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t,
T *, T &> {
: public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t, T *, T &> {
public:
RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
......@@ -54,13 +53,11 @@ public:
return *this;
}
bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext>
&rhs) const {
bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext> &rhs) const {
return (ptr_ + i_) == &(*rhs);
}
bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext>
&rhs) const {
bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext> &rhs) const {
return (ptr_ + i_) != &(*rhs);
}
......@@ -72,15 +69,15 @@ private:
int64_t n_;
};
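This iterator is what lets a length-n row of Y be broadcast across X in the elementwise-add kernel further down without materializing copies: incrementing it walks Y cyclically while the transform walks X linearly. The same effect written as a plain index loop (a sketch, not the Paddle class):

#include <cstdio>
#include <vector>

// Index-loop equivalent of RowwiseTransformIterator: x is walked linearly
// while y's n entries are read cyclically, so one row of y is broadcast-added
// across every row of x.
int main() {
  std::vector<int> x = {1, 2, 3, 4, 5, 6};  // 2 rows x 3 cols, flattened
  std::vector<int> y = {10, 20, 30};        // one row, n = 3
  std::vector<int> out(x.size());
  const std::size_t n = y.size();
  for (std::size_t i = 0; i < x.size(); ++i)
    out[i] = x[i] + y[i % n];               // cyclic read == iterator increment
  for (int v : out) std::printf("%d ", v);  // 11 22 33 14 25 36
  std::printf("\n");
}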
template <typename T> struct AddFunctor {
template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(T x, T y) { return x + y; }
};
struct GetMidDims {
inline HOSTDEVICE void operator()(const framework::DDim &x_dims,
const framework::DDim &y_dims,
const int axis, int *pre, int *n,
int *post) {
const framework::DDim &y_dims, const int axis,
int *pre, int *n, int *post) {
*pre = 1;
*n = 1;
*post = 1;
......@@ -105,18 +102,17 @@ const size_t SHARE_NUM = 2;
template <typename DeviceContext, typename T>
class MpcElementwiseAddKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *in_x_t = ctx.Input<Tensor>("X");
auto *in_y_t = ctx.Input<Tensor>("Y");
auto *out_t = ctx.Output<Tensor>("Out");
void ComputeImpl(const framework::ExecutionContext &ctx) const override{
auto *in_x_t = ctx.Input<framework::LoDTensor>("X");
auto *in_y_t = ctx.Input<framework::LoDTensor>("Y");
auto *out_t = ctx.Output<framework::LoDTensor>("Out");
int axis = ctx.Attr<int>("axis");
auto out = out_t->mutable_data<T>(ctx.GetPlace());
if (in_x_t->dims() == in_y_t->dims()) {
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(
in_x_t, in_y_t, out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(in_x_t, in_y_t, out_t);
} else {
Tensor in_x_t_slice;
Tensor in_y_t_slice;
......@@ -137,8 +133,8 @@ public:
int pre, n, post;
GetMidDims get_mid_dims;
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
PADDLE_ENFORCE_EQ(
post, 1, "post should be equal 1, but received post is [%s]", post);
PADDLE_ENFORCE_EQ(post, 1,
"post should be equal 1, but received post is [%s]", post);
auto x_ = in_x_t_slice.data<T>();
auto y_ = in_y_t_slice.data<T>();
......@@ -146,8 +142,8 @@ public:
auto nx_ = in_x_t_slice.numel();
paddle::platform::Transform<DeviceContext> trans;
trans(ctx.template device_context<DeviceContext>(), x_, x_ + nx_,
RowwiseTransformIterator<T, DeviceContext>(y_, n), out_,
AddFunctor<T>());
RowwiseTransformIterator<T, DeviceContext>(y_, n),
out_, AddFunctor<T>());
}
}
}
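For the unequal-dims branch, GetMidDims (its loop body is collapsed in this diff) factors each share slice into (pre, n, post) around axis, and the kernel then enforces post == 1 so the RowwiseTransformIterator path applies. A runnable sketch of that decomposition under the usual elementwise rules (trailing-dimension handling simplified; names are illustrative):

#include <cassert>
#include <cstdio>
#include <vector>

// pre  = product of x dims before `axis`
// n    = product of y dims (which must match x dims starting at `axis`)
// post = product of x dims after the y block; the kernel above requires post == 1
void get_mid_dims_sketch(const std::vector<int>& x_dims,
                         const std::vector<int>& y_dims, int axis,
                         int* pre, int* n, int* post) {
  *pre = 1; *n = 1; *post = 1;
  for (int i = 0; i < axis; ++i) *pre *= x_dims[i];
  for (std::size_t i = 0; i < y_dims.size(); ++i) {
    assert(x_dims[axis + i] == y_dims[i]);
    *n *= y_dims[i];
  }
  for (std::size_t i = axis + y_dims.size(); i < x_dims.size(); ++i)
    *post *= x_dims[i];
}

int main() {
  int pre, n, post;
  // One share slice: X is 4x3, Y is a length-3 row broadcast along axis 1.
  get_mid_dims_sketch({4, 3}, {3}, /*axis=*/1, &pre, &n, &post);
  std::printf("pre=%d n=%d post=%d\n", pre, n, post);  // pre=4 n=3 post=1
}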
......@@ -159,9 +155,9 @@ public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *in_x_t = ctx.Input<framework::LoDTensor>("X");
auto *in_y_t = ctx.Input<framework::LoDTensor>("Y");
auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto *dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto *dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
auto dout_data = dout->data<T>();
......@@ -189,8 +185,8 @@ public:
int pre, n, post;
GetMidDims get_mid_dims;
get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);
PADDLE_ENFORCE_EQ(
post, 1, "post should be equal 1, but received post is [%s]", post);
PADDLE_ENFORCE_EQ(post, 1,
"post should be equal 1, but received post is [%s]", post);
for (size_t i = 0; i < SHARE_NUM; ++i) {
int y_offset = i * n;
......@@ -212,3 +208,4 @@ public:
} // namespace operators
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mpc_elementwise_sub_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_elementwise_sub_op.h"
namespace paddle {
namespace operators {
......@@ -22,25 +22,21 @@ class MpcElementwiseSubOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of MpcElementwiseSubOp should not be null."));
platform::errors::NotFound("Input(X) of MpcElementwiseSubOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound(
"Input(Y) of MpcElementwiseSubOp should not be null."));
platform::errors::NotFound("Input(Y) of MpcElementwiseSubOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of MpcElementwiseSubOp should not be null."));
platform::errors::NotFound("Output(Out) of MpcElementwiseSubOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("X"), ctx->GetInputDim("Y"),
platform::errors::InvalidArgument(
"The dimensions of X should be equal with the dimensions of Y. "
"But received the dimensions of X is [%s], the dimensions of Y is "
"[%s]",
"But received the dimensions of X is [%s], the dimensions of Y is [%s]",
ctx->GetInputDim("X"), ctx->GetInputDim("Y")));
ctx->ShareDim("X", /*->*/ "Out");
......@@ -51,10 +47,8 @@ public:
class MpcElementwiseSubOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), The first input tensor of mpc elementwise sub op.");
AddInput("Y",
"(Tensor), The second input tensor of mpc elementwise sub op.");
AddInput("X", "(Tensor), The first input tensor of mpc elementwise sub op.");
AddInput("Y", "(Tensor), The second input tensor of mpc elementwise sub op.");
AddOutput("Out", "(Tensor), The output tensor of mpc elementwise sub op.");
AddComment(R"DOC(
MPC elementwise sub Operator.
......@@ -86,21 +80,19 @@ public:
};
template <typename T>
class MpcElementwiseSubGradMaker : public framework::SingleGradOpDescMaker {
class MpcElementwiseSubGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> retv(new T());
retv->SetType("mpc_elementwise_sub_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput("Y", this->Input("Y"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
retv->SetAttrMap(this->Attrs());
return retv;
void Apply(GradOpPtr<T> grad) const override {
grad->SetType("mpc_elementwise_sub_grad");
grad->SetInput("X", this->Input("X"));
grad->SetInput("Y", this->Input("Y"));
grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
grad->SetAttrMap(this->Attrs());
}
};
......@@ -118,6 +110,6 @@ REGISTER_OP_CPU_KERNEL(
mpc_elementwise_sub,
ops::MpcElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(mpc_elementwise_sub_grad,
ops::MpcElementwiseSubGradKernel<
paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
mpc_elementwise_sub_grad,
ops::MpcElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// This op is different with elementwise_sub of PaddlePaddle.
// We only consider that the dimensions of X is equal with the dimensions of Y.
#pragma once
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle {
namespace operators {
......@@ -33,8 +32,7 @@ public:
auto *out_t = ctx.Output<Tensor>("Out");
auto out = out_t->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(
in_x_t, in_y_t, out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(in_x_t, in_y_t, out_t);
}
};
......@@ -56,11 +54,11 @@ public:
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->neg(
dout, dy);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->neg(dout, dy);
}
}
};
} // namespace operators
} // namespace paddle
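The gradient rule these kernels implement follows directly from out = x - y: d(out)/dx = 1 and d(out)/dy = -1, hence dx = dout and dy = -dout, which is why the backward pass only needs the neg call above. A trivial plaintext check:

#include <cstdio>

// For out = x - y: dx = dout, dy = -dout (the MPC kernel realizes the sign
// flip with mpc_operators()->neg(dout, dy)).
int main() {
  const double dout = 0.25;
  std::printf("dx = %f, dy = %f\n", dout, -dout);  // dx = 0.25, dy = -0.25
}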
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Description:
#include "paddle/fluid/framework/op_registry.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_config.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_config.h"
namespace paddle {
namespace operators {
......@@ -26,10 +26,10 @@ using mpc::Aby3Config;
class MpcInitOp : public framework::OperatorBase {
public:
MpcInitOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
MpcInitOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
......@@ -55,24 +55,26 @@ public:
AddComment(R"DOC(
MpcInit Operator.
)DOC");
AddAttr<std::string>("protocol_name", "(string , default aby3)"
AddAttr<std::string>("protocol_name",
"(string , default aby3)"
"protocol name")
.SetDefault({"aby3"});
AddAttr<int>("role", "trainer role.").SetDefault(0);
AddAttr<std::string>("local_addr", "(string, default localhost)"
AddAttr<std::string>("local_addr",
"(string, default localhost)"
"local addr")
.SetDefault({"localhost"});
AddAttr<std::string>("net_server_addr", "(string, default localhost)"
AddAttr<std::string>("net_server_addr",
"(string, default localhost)"
"net server addr")
.SetDefault({"localhost"});
AddAttr<int>("net_server_port", "net server port, default to 6539.")
.SetDefault(6539);
AddAttr<int>("net_server_port", "net server port, default to 6539.").SetDefault(6539);
}
};
class MpcInitOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
......@@ -80,5 +82,7 @@ public:
namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_init, ops::MpcInitOp, ops::MpcInitOpMaker,
ops::MpcInitOpShapeInference);
REGISTER_OPERATOR(
mpc_init, ops::MpcInitOp,
ops::MpcInitOpMaker, ops::MpcInitOpShapeInference);
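For reference, the attribute surface MpcInitOpMaker declares, gathered into a plain struct with the same defaults (illustrative only; the real op reads these from the framework::AttributeMap at run time):

#include <cstdio>
#include <string>

struct MpcInitAttrs {
  std::string protocol_name = "aby3";        // MPC protocol to instantiate
  int role = 0;                              // this trainer's party id
  std::string local_addr = "localhost";
  std::string net_server_addr = "localhost";
  int net_server_port = 6539;
};

int main() {
  MpcInitAttrs a;
  std::printf("%s party %d via %s:%d\n", a.protocol_name.c_str(), a.role,
              a.net_server_addr.c_str(), a.net_server_port);
}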
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mpc_mean_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_mean_op.h"
namespace paddle {
namespace operators {
......@@ -24,13 +24,13 @@ class MpcMeanOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of MpcMeanOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of MpcMeanOp should not be null."));
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of MpcMeanOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of MpcMeanOp should not be null."));
ctx->SetOutputDim("Out", {2, 1});
}
};
......@@ -48,9 +48,10 @@ MPC mean Operator calculates the mean of all elements in X.
class MpcMeanOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>
GetInputOutputWithSameType() const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
......@@ -63,21 +64,20 @@ public:
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", framework::GradVarName("X"));
}
};
template <typename T>
class MpcMeanOpGradMaker : public framework::SingleGradOpDescMaker {
class MpcMeanOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> retv(new T());
retv->SetType("mpc_mean_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
return retv;
void Apply(GradOpPtr<T> grad) const override {
grad->SetType("mpc_mean_grad");
grad->SetInput("X", this->Input("X"));
grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
......@@ -85,14 +85,16 @@ protected:
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_mean, ops::MpcMeanOp, ops::MpcMeanOpMaker,
REGISTER_OPERATOR(mpc_mean, ops::MpcMeanOp,
ops::MpcMeanOpMaker,
ops::MpcMeanOpInferVarType,
ops::MpcMeanOpGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_mean_grad, ops::MpcMeanGradOp);
REGISTER_OP_CPU_KERNEL(
mpc_mean, ops::MpcMeanKernel<paddle::platform::CPUDeviceContext, int64_t>);
mpc_mean,
ops::MpcMeanKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
mpc_mean_grad,
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "mpc_op.h"
#include "paddle/fluid/framework/eigen.h"
namespace paddle {
namespace operators {
......@@ -33,10 +32,8 @@ public:
auto *out_t = ctx.Output<Tensor>("Out");
out_t->mutable_data<T>(ctx.GetPlace());
double scale = 1.0 / (in_x_t->numel() / 2.0);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sum(
in_x_t, out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(
out_t, scale, out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sum(in_x_t, out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(out_t, scale, out_t);
}
};
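The 1.0 / (numel() / 2.0) factor accounts for the leading share dimension: an MPC tensor stores SHARE_NUM = 2 secret shares, so only numel()/2 entries are logical elements, and the mean scales the MPC sum accordingly. A toy check for a [2, 3, 4] share tensor:

#include <cstdio>

int main() {
  const long numel = 2 * 3 * 4;             // 24 raw entries, 12 logical values
  const double scale = 1.0 / (numel / 2.0);
  std::printf("scale = %f\n", scale);       // 0.083333, i.e. 1/12
}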
......@@ -45,8 +42,7 @@ class MpcMeanGradKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE(dout->numel() == 2,
"numel of MpcMean Gradient should be 2.");
PADDLE_ENFORCE(dout->numel() == 2, "numel of MpcMean Gradient should be 2.");
auto dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dout_data = dout->data<T>();
......@@ -60,11 +56,11 @@ public:
}
double scale_factor = 1.0 / (dx->numel() / 2);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(
dx, scale_factor, dx);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(dx, scale_factor, dx);
}
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mpc_mul_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_mul_op.h"
namespace paddle {
namespace operators {
......@@ -24,16 +24,16 @@ class MpcMulOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of Mpc MulOp should not be null."));
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of Mpc MulOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Y"), true,
platform::errors::NotFound("Input(Y) of MpcMulOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of MpcMulOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of MpcMulOp should not be null."));
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
......@@ -86,8 +86,8 @@ public:
x_dims, x_mat_width, y_dims, y_mat_height));
std::vector<int64_t> output_dims;
output_dims.reserve(static_cast<size_t>(1 + x_num_col_dims + y_dims.size() -
y_num_col_dims));
output_dims.reserve(
static_cast<size_t>(1 + x_num_col_dims + y_dims.size() - y_num_col_dims));
for (int i = 0; i <= x_num_col_dims; ++i) { // i=0, batch_size (share id)
output_dims.push_back(x_dims[i]);
......@@ -153,7 +153,8 @@ public:
"same purpose as scale_weights in OPs that support quantization."
"Only to be used with MKL-DNN INT8")
.SetDefault({1.0f});
AddAttr<float>("scale_out", "scale_out to be used for int8 output data."
AddAttr<float>("scale_out",
"scale_out to be used for int8 output data."
"Only used with MKL-DNN INT8")
.SetDefault(1.0f);
AddAttr<bool>(
......@@ -169,9 +170,10 @@ MPC mul Operator.
class MpcMulOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>
GetInputOutputWithSameType() const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
......@@ -202,36 +204,37 @@ public:
};
template <typename T>
class MpcMulOpGradMaker : public framework::SingleGradOpDescMaker {
class MpcMulOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> retv(new T());
retv->SetType("mpc_mul_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput("Y", this->Input("Y"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
retv->SetAttrMap(this->Attrs());
return retv;
void Apply(GradOpPtr<T> grad) const override {
grad->SetType("mpc_mul_grad");
grad->SetInput("X", this->Input("X"));
grad->SetInput("Y", this->Input("Y"));
grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
grad->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_mul, ops::MpcMulOp, ops::MpcMulOpMaker,
REGISTER_OPERATOR(mpc_mul, ops::MpcMulOp,
ops::MpcMulOpMaker,
ops::MpcMulOpInferVarType,
ops::MpcMulOpGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_mul_grad, ops::MpcMulGradOp);
REGISTER_OP_CPU_KERNEL(
mpc_mul, ops::MpcMulKernel<paddle::platform::CPUDeviceContext, int64_t>);
mpc_mul,
ops::MpcMulKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
mpc_mul_grad,
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle {
namespace operators {
......@@ -48,7 +47,7 @@ public:
}
for (size_t i = 1; i < y_dims.size(); i++) {
if (i <= y_num_col_dims) {
x_mat_width *= y_dims[i];
y_mat_width *= y_dims[i];
} else {
y_mat_height *= y_dims[i];
}
......@@ -59,13 +58,8 @@ public:
x_matrix.ShareDataWith(*x);
y_matrix.ShareDataWith(*y);
if (x_dims.size() > 3) {
x_matrix.Resize({2, x_mat_width, x_mat_height});
}
if (y_dims.size() > 3) {
y_matrix.Resize({2, y_mat_width, y_mat_height});
}
out->mutable_data<T>(ctx.GetPlace());
......@@ -80,15 +74,17 @@ public:
if (out_dim.size() > 3) {
out->Resize(out_dim);
}
}
};
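The flattening in InferShape and in this kernel treats dim 0 as the share dimension, folding dims 1..num_col_dims into the matrix row count (x_mat_width above) and the remaining dims into the column count (x_mat_height). A runnable sketch of that fold for a hypothetical [2, 3, 4, 5] tensor with x_num_col_dims = 2:

#include <cstdio>
#include <vector>

int main() {
  const std::vector<long> x_dims = {2, 3, 4, 5};  // share dim first
  const std::size_t x_num_col_dims = 2;
  long width = 1, height = 1;
  for (std::size_t i = 1; i < x_dims.size(); ++i)
    (i <= x_num_col_dims ? width : height) *= x_dims[i];  // rows vs. cols
  std::printf("matrix view: {2, %ld, %ld}\n", width, height);  // {2, 12, 5}
}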
template <typename DeviceContext, typename T>
class MpcMulGradKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<framework::LoDTensor>("X");
auto *y = ctx.Input<framework::LoDTensor>("Y");
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto *dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto *dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
......@@ -125,17 +121,9 @@ public:
y_matrix.ShareDataWith(*y);
dout_matrix.ShareDataWith(*dout);
if (x_dims.size() > 3) {
x_matrix.Resize({2, x_mat_width, x_mat_height});
}
if (y_dims.size() > 3) {
y_matrix.Resize({2, y_mat_width, y_mat_height});
}
if (dout_dims.size() > 3) {
dout_matrix.Resize({2, x_mat_width, y_mat_height});
}
if (dx != nullptr) {
dx->set_lod(x->lod());
......@@ -149,15 +137,10 @@ public:
x_matrix_trans.mutable_data<T>(x->dims(), ctx.GetPlace());
y_matrix_trans.mutable_data<T>(y->dims(), ctx.GetPlace());
if (x_dims.size() >= 3) {
x_matrix_trans.Resize({2, x_mat_height, x_mat_width});
}
if (y_dims.size() >= 3) {
y_matrix_trans.Resize({2, y_mat_height, y_mat_width});
}
auto &dev_ctx = ctx.template device_context<DeviceContext>();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
const int Rank = 3;
Eigen::array<int, Rank> permute;
......@@ -172,7 +155,7 @@ public:
}
auto eigen_in = framework::EigenTensor<T, Rank>::From(y_matrix);
auto eigen_out = framework::EigenTensor<T, Rank>::From(y_matrix_trans);
auto *dev = dev_ctx.eigen_device();
auto* dev = dev_ctx.eigen_device();
eigen_out.device(*dev) = eigen_in.shuffle(permute);
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(
......@@ -191,7 +174,7 @@ public:
auto eigen_in = framework::EigenTensor<T, Rank>::From(x_matrix);
auto eigen_out = framework::EigenTensor<T, Rank>::From(x_matrix_trans);
auto *dev = dev_ctx.eigen_device();
auto* dev = dev_ctx.eigen_device();
eigen_out.device(*dev) = eigen_in.shuffle(permute);
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->matmul(
......@@ -206,3 +189,4 @@ public:
} // namespace operators
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Description:
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/privc3/circuit_context.h"
namespace paddle {
namespace operators {
template <typename T> class MpcOpKernel : public framework::OpKernelBase {
template <typename T>
class MpcOpKernel : public framework::OpKernelBase {
public:
using ELEMENT_TYPE = T;
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_instance()->mpc_protocol(),
"Mpc protocol is not yet initialized in executor");
std::shared_ptr<aby3::CircuitContext> mpc_ctx(
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context());
std::shared_ptr<aby3::CircuitContext> mpc_ctx(mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context());
mpc::ContextHolder::template run_with_context<>(&ctx, mpc_ctx,
[&] { ComputeImpl(ctx); });
}
virtual void ComputeImpl(const framework::ExecutionContext &ctx) const = 0;
virtual void ComputeImpl(const framework::ExecutionContext& ctx) const = 0;
};
} // namespace operators
} // namespace paddle
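MpcOpKernel is a template-method wrapper: the non-virtual Compute installs the MPC circuit context for the duration of the call, then defers to the virtual ComputeImpl that each operator overrides. A self-contained sketch of that shape (Context and RunWithContext are stand-ins, not the PaddleFL API):

#include <functional>
#include <iostream>
#include <memory>

struct Context {};

void RunWithContext(const std::shared_ptr<Context>&,
                    const std::function<void()>& fn) {
  // real code: mpc::ContextHolder::run_with_context(&ctx, mpc_ctx, [&]{...});
  fn();
}

struct MpcKernelBase {
  void Compute() const {
    auto mpc_ctx = std::make_shared<Context>();  // protocol()->mpc_context()
    RunWithContext(mpc_ctx, [this] { ComputeImpl(); });
  }
  virtual void ComputeImpl() const = 0;
  virtual ~MpcKernelBase() = default;
};

struct ReluKernel : MpcKernelBase {
  void ComputeImpl() const override { std::cout << "relu under MPC context\n"; }
};

int main() { ReluKernel{}.Compute(); }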
......@@ -18,20 +18,20 @@
namespace paddle {
namespace operators {
// forward op definition
//forward op definition
class MpcReluOp : public framework::OperatorWithKernel {
public:
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Y", in_dims);
}
};
// forward input & output definition
//forward input & output definition
class MpcReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
public:
void Make() override {
AddInput("X", "The input tensor.");
AddOutput("Y", "Output of relu_op");
......@@ -41,31 +41,30 @@ Mpc Relu Operator.
}
};
// backward op definition
//backward op definition
class MpcReluGradOp : public framework::OperatorWithKernel {
public:
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim(framework::GradVarName("Y"));
ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
}
};
// backward type, input & output definition
//backward type, input & output definition
template <typename T>
class MpcReluGradMaker : public framework::SingleGradOpDescMaker {
class MpcReluGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
std::unique_ptr<T> Apply() const override {
auto *op = new T();
op->SetType("mpc_relu_grad");
op->SetInput("Y", this->Output("Y"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
return std::unique_ptr<T>(op);
protected:
void Apply(GradOpPtr<T> grad) const override {
grad->SetType("mpc_relu_grad");
grad->SetInput("Y", this->Output("Y"));
grad->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
grad->SetAttrMap(this->Attrs());
grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
......@@ -76,8 +75,12 @@ namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(mpc_relu, ops::MpcReluOp, ops::MpcReluOpMaker,
REGISTER_OPERATOR(mpc_relu,
ops::MpcReluOp,
ops::MpcReluOpMaker,
ops::MpcReluGradMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_relu_grad, ops::MpcReluGradOp);
REGISTER_OP_CPU_KERNEL(mpc_relu, ops::MpcReluKernel<CPU, int64_t>);
REGISTER_OP_CPU_KERNEL(mpc_relu_grad, ops::MpcReluGradKernel<CPU, int64_t>);
REGISTER_OP_CPU_KERNEL(mpc_relu,
ops::MpcReluKernel<CPU, int64_t>);
REGISTER_OP_CPU_KERNEL(mpc_relu_grad,
ops::MpcReluGradKernel<CPU, int64_t>);
......@@ -14,43 +14,37 @@
#pragma once
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Define forward computation
//Define forward computation
template <typename DeviceContext, typename T>
class MpcReluKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
const Tensor *in_t = ctx.Input<Tensor>("X");
Tensor *out_t = ctx.Output<Tensor>("Y");
void ComputeImpl(const framework::ExecutionContext& ctx) const override {
const Tensor* in_t = ctx.Input<Tensor>("X");
Tensor* out_t = ctx.Output<Tensor>("Y");
auto x = in_t->data<T>();
auto y = out_t->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol,
"Protocol %s is not yet created in MPC Protocol.");
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->relu(
in_t, out_t);
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, "Protocol is not yet created in MPC Protocol.");
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->relu(in_t, out_t);
}
};
// Define backward computation
//Define backward computation
template <typename DeviceContext, typename T>
class MpcReluGradKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
auto *dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto *y_t = ctx.Input<Tensor>("Y");
auto *dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
void ComputeImpl(const framework::ExecutionContext& ctx) const override {
auto* dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* y_t = ctx.Input<Tensor>("Y");
auto* dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dx = dx_t->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance()
->mpc_protocol()
->mpc_operators()
->relu_grad(y_t, dy_t, dx_t, 0.0);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->relu_grad(y_t, dy_t, dx_t, 0.0);
}
};
} // namespace operators
} // namespace paddle
}// namespace operators
}// namespace paddle
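A plaintext analogue of the relu / relu_grad pair these kernels invoke: forward y = max(x, 0); backward dx = dy where y > 0, else 0. Reading the trailing 0.0 argument of relu_grad as the negative-side slope (i.e., plain, non-leaky ReLU) is an assumption here, not something the diff states:

#include <cstdio>

int main() {
  const double x = -1.5, dy = 0.7;
  const double y = x > 0.0 ? x : 0.0;        // forward
  const double dx = y > 0.0 ? dy : 0.0;      // backward
  std::printf("y = %f, dx = %f\n", y, dx);   // y = 0.000000, dx = 0.000000
}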
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mpc_sgd_op.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -55,8 +55,8 @@ public:
}
protected:
framework::OpKernelType
GetExpectedKernelType(const framework::ExecutionContext &ctx) const override {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(data_type, ctx.device_context());
}
......@@ -65,19 +65,19 @@ protected:
class MpcSGDOpInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto &input_var_n = ctx->Input("Param")[0];
auto in_var_type = ctx->GetType(input_var_n);
auto in_var_type = ctx->GetInputType("Param");
PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
in_var_type == framework::proto::VarType::LOD_TENSOR,
"The input Var's type should be LoDtensor or SelectedRows,"
" but the received var(%s)'s type is %s",
input_var_n, in_var_type);
ctx->InputVarName("Param"), in_var_type);
ctx->SetOutputType("ParamOut", in_var_type);
for (auto &out_var_n : ctx->Output("ParamOut")) {
if (ctx->GetType(out_var_n) != in_var_type) {
ctx->SetType(out_var_n, in_var_type);
}
}
//for (auto &out_var_n : framework::StaticGraphVarTypeInference::Output(ctx, "ParamOut")) {
// if (ctx->GetVarType(out_var_n) != in_var_type) {
// ctx->SetType(out_var_n, in_var_type);
//}
//}
}
};
......@@ -108,7 +108,7 @@ $$param\_out = param - learning\_rate * grad$$
namespace ops = paddle::operators;
REGISTER_OPERATOR(
mpc_sgd, ops::MpcSGDOp, ops::MpcSGDOpMaker,
// paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
ops::MpcSGDOpInferVarType);
REGISTER_OP_CPU_KERNEL(
mpc_sgd, ops::MpcSGDOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
mpc_sgd,
ops::MpcSGDOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "mpc_op.h"
#include "paddle/fluid/framework/eigen.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class MpcSGDOpKernel : public MpcOpKernel<T> {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override {
public:
void ComputeImpl(const framework::ExecutionContext &ctx) const override{
const auto *param_var = ctx.InputVar("Param");
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(),
ctx.InputNames("Param").front(),
framework::ToTypeName(param_var->Type()));
const auto *grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(),
ctx.InputNames("Grad").front(),
framework::ToTypeName(grad_var->Type()));
const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
......@@ -49,19 +48,14 @@ public:
PADDLE_ENFORCE_EQ(grad->numel(), sz);
const double *lr = learning_rate->data<double>();
// const T *param_data = param->data<T>();
// const T *grad_data = grad->data<T>();
T *out_data = param_out->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol,
"Protocol is not yet created in MPC Protocol.");
param_out->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_protocol, "Protocol is not yet created in MPC Protocol.");
// update parameters
framework::Tensor temp;
temp.mutable_data<T>(param->dims(), ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(
grad, lr[0], &temp);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(
param, &temp, param_out);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(grad, lr[0], &temp);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->sub(param, &temp, param_out);
}
};
} // namespace operators
......
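In plaintext terms, the two mpc_operators calls above decompose the SGD rule param_out = param - learning_rate * grad into a scale followed by a subtraction. A minimal Python sketch of the same arithmetic, ignoring the fixed-point share encoding:

import numpy as np

def sgd_update(param, grad, lr):
    temp = grad * lr      # scale(grad, lr[0], &temp)
    return param - temp   # sub(param, &temp, param_out)

param = np.array([1.0, 2.0])
grad = np.array([0.5, 0.5])
print(sgd_update(param, grad, lr=0.1))  # [0.95 1.95]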
......@@ -117,21 +117,19 @@ MpcSigmoidCrossEntropyWithLogits Operator.
};
template <typename T>
class MpcSigmoidCrossEntropyWithLogitsGradOpMaker : public framework::SingleGradOpDescMaker {
class MpcSigmoidCrossEntropyWithLogitsGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> retv(new T());
retv->SetType("mpc_sigmoid_cross_entropy_with_logits_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput("Label", this->Input("Label"));
retv->SetInput("Out", this->Output("Out"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
return retv;
void Apply(GradOpPtr<T> grad) const override {
grad->SetType("mpc_sigmoid_cross_entropy_with_logits_grad");
grad->SetInput("X", this->Input("X"));
grad->SetInput("Label", this->Input("Label"));
grad->SetInput("Out", this->Output("Out"));
grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad->SetAttrMap(this->Attrs());
}
};
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mpc_square_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_square_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class MpcSquareOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of MpcSquareOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of MpcSquareOp should not be null."));
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of MpcSquareOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of MpcSquareOp should not be null."));
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
......@@ -59,26 +60,26 @@ public:
};
template <typename T>
class MpcSquareGradOpMaker : public framework::SingleGradOpDescMaker {
class MpcSquareGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> retv(new T());
retv->SetType("mpc_square_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
return retv;
void Apply(GradOpPtr<T> grad) const override {
grad->SetType("mpc_square_grad");
grad->SetInput("X", this->Input("X"));
grad->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(mpc_square, ops::MpcSquareOp, ops::MpcSquareOpMaker,
REGISTER_OPERATOR(mpc_square, ops::MpcSquareOp,
ops::MpcSquareOpMaker,
ops::MpcSquareGradOpMaker<paddle::framework::OpDesc>);
REGISTER_OPERATOR(mpc_square_grad, ops::MpcSquareGradOp);
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace paddle {
namespace operators {
......@@ -27,8 +27,7 @@ public:
auto *in_x_t = ctx.Input<Tensor>("X");
auto *out_t = ctx.Output<Tensor>("Out");
out_t->mutable_data<T>(ctx.GetPlace());
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul(
in_x_t, in_x_t, out_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul(in_x_t, in_x_t, out_t);
}
};
......@@ -43,13 +42,12 @@ public:
// allocate memory on device.
dx_t->mutable_data<T>(ctx.GetPlace());
// dx = dout * 2 * x
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(
in_x_t, 2.0, dx_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul(
dx_t, dout_t, dx_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->scale(in_x_t, 2.0, dx_t);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->mul(dx_t, dout_t, dx_t);
}
}
};
} // namespace operators
} // namespace paddle
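The backward kernel above realizes dx = dout * 2 * x as a scale followed by an in-place multiply. The same decomposition in plaintext Python, for illustration only:

import numpy as np

def square_grad(x, dout):
    dx = 2.0 * x     # scale(in_x_t, 2.0, dx_t)
    dx = dx * dout   # mul(dx_t, dout_t, dx_t), reusing dx as the kernel does
    return dx

x = np.array([3.0])
dout = np.array([1.0])
print(square_grad(x, dout))  # [6.], the derivative of x**2 at x = 3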
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "mpc_sum_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_sum_op.h"
namespace paddle {
namespace operators {
......@@ -31,11 +30,10 @@ class MpcSumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInputs("X"), true,
platform::errors::NotFound(
"Input(X) of MpcElementwiseAddOp should not be null."));
platform::errors::NotFound("Input(X) of MpcElementwiseAddOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
......@@ -45,7 +43,8 @@ public:
auto x_dims = ctx->GetInputsDim("X");
auto N = x_dims.size();
PADDLE_ENFORCE_GT(
N, 0, "ShapeError: The input tensor X's dimensions of SumOp "
N, 0,
"ShapeError: The input tensor X's dimensions of SumOp "
"should be larger than 0. But received X's dimensions %d, "
"X's shape = [%s].",
N, &x_dims);
......@@ -55,7 +54,7 @@ public:
framework::DDim in_dim({0});
for (size_t i = 0; i < x_dims.size(); ++i) {
auto &x_dim = x_dims[i];
auto& x_dim = x_dims[i];
// x_dim.size() == 1 means the real dim of selected rows is [0]
if (x_var_types[i] == framework::proto::VarType::SELECTED_ROWS &&
x_dim.size() == 1) {
......@@ -99,6 +98,7 @@ public:
ctx->SetOutputDim("Out", in_dim);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class MpcSumOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -110,7 +110,8 @@ public:
"or LoDTensor, and data types can be: float32, float64, int32, "
"int64.")
.AsDuplicable();
AddOutput("Out", "the sum of input :code:`x`. its shape and data types are "
AddOutput("Out",
"the sum of input :code:`x`. its shape and data types are "
"consistent with :code:`x`.");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
......@@ -121,6 +122,7 @@ public:
}
};
class MpcSumGradMaker : public framework::GradOpDescMakerBase {
public:
using framework::GradOpDescMakerBase::GradOpDescMakerBase;
......@@ -131,8 +133,8 @@ public:
grad_ops.reserve(x_grads.size());
auto og = OutputGrad("Out");
std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
[&og](const std::string &x_grad) {
auto *grad_op = new framework::OpDesc();
[&og](const std::string& x_grad) {
auto* grad_op = new framework::OpDesc();
grad_op->SetType("scale");
grad_op->SetInput("X", og);
grad_op->SetOutput("Out", {x_grad});
......@@ -151,9 +153,10 @@ DECLARE_INPLACE_OP_INFERER(MpcSumInplace, {"X", "Out"});
namespace ops = paddle::operators;
// REGISTER_OP_WITHOUT_GRADIENT(mpc_sum, ops::MpcSumOp, ops::MpcSumOpMaker);
REGISTER_OPERATOR(mpc_sum, ops::MpcSumOp, ops::MpcSumOpMaker,
ops::MpcSumGradMaker, ops::MpcSumInplace);
//REGISTER_OP_WITHOUT_GRADIENT(mpc_sum, ops::MpcSumOp, ops::MpcSumOpMaker);
REGISTER_OPERATOR(mpc_sum, ops::MpcSumOp,
ops::MpcSumOpMaker,
ops::MpcSumGradMaker,
ops::MpcSumInplace);
REGISTER_OP_CPU_KERNEL(
mpc_sum, ops::MpcSumKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(mpc_sum, ops::MpcSumKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "mpc_op.h"
......@@ -45,10 +45,7 @@ public:
auto &in_0 = in_vars[0]->Get<framework::LoDTensor>();
auto &in_1 = in_vars[1]->Get<framework::LoDTensor>();
if (in_0.numel() && in_1.numel()) {
mpc::MpcInstance::mpc_instance()
->mpc_protocol()
->mpc_operators()
->add(&in_0, &in_1, out);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(&in_0, &in_1, out);
start = 2;
}
}
......@@ -66,15 +63,12 @@ public:
if (in_t.numel() == 0) {
continue;
}
mpc::MpcInstance::mpc_instance()
->mpc_protocol()
->mpc_operators()
->add(out, &in_t, out);
mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_operators()->add(out, &in_t, out);
} else {
PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
}
} else {
}else {
PADDLE_THROW("Unexpected branch, output variable type is %s",
framework::ToTypeName(out_var->Type()));
}
......@@ -82,3 +76,4 @@ public:
};
} // namespace operators
} // namespace paddle
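The kernel above seeds the output by adding the first two non-empty inputs, then folds in the remaining inputs one at a time. The control flow in plaintext Python, assuming at least two non-empty inputs:

import numpy as np

def sum_inputs(inputs):
    out = inputs[0] + inputs[1]   # add(&in_0, &in_1, out); start = 2
    for in_t in inputs[2:]:
        if in_t.size == 0:        # skip empty tensors, as the kernel does
            continue
        out = out + in_t          # add(out, &in_t, out)
    return out

print(sum_inputs([np.array([1]), np.array([2]), np.array([3])]))  # [6]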
This diff is collapsed.
......@@ -18,8 +18,8 @@
#include "gtest/gtest.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/framework/tensor.h"
namespace aby3 {
......@@ -28,10 +28,13 @@ using paddle::framework::Tensor;
class PaddleTensorTest : public ::testing::Test {
public:
std::shared_ptr<TensorAdapterFactory> _tensor_factory;
CPUDeviceContext _cpu_ctx;
virtual ~PaddleTensorTest() noexcept {}
void SetUp() {
_tensor_factory = std::make_shared<PaddleTensorFactory>(&_cpu_ctx);
}
......@@ -39,21 +42,20 @@ public:
TEST_F(PaddleTensorTest, factory_test) {
EXPECT_NO_THROW(_tensor_factory->template create<int64_t>());
std::vector<size_t> shape = {2, 3};
std::vector<size_t> shape = { 2, 3 };
EXPECT_NO_THROW(_tensor_factory->template create<int64_t>(shape));
}
TEST_F(PaddleTensorTest, ctor_test) {
Tensor t;
// t holds no memory
EXPECT_THROW({ PaddleTensor<int64_t> pt(&_cpu_ctx, t); },
::paddle::platform::EnforceNotMet);
EXPECT_THROW({ PaddleTensor<int64_t> pt(&_cpu_ctx, t); }, ::paddle::platform::EnforceNotMet);
t.template mutable_data<int64_t>(_cpu_ctx.GetPlace());
EXPECT_NO_THROW({ PaddleTensor<int64_t> pt(&_cpu_ctx, t); });
}
TEST_F(PaddleTensorTest, shape_test) {
std::vector<size_t> shape = {2, 3};
std::vector<size_t> shape = { 2, 3 };
auto pt = _tensor_factory->template create<int64_t>(shape);
EXPECT_EQ(shape.size(), pt->shape().size());
......@@ -65,7 +67,7 @@ TEST_F(PaddleTensorTest, shape_test) {
}
TEST_F(PaddleTensorTest, reshape_test) {
std::vector<size_t> shape = {2, 3};
std::vector<size_t> shape = { 2, 3 };
auto pt = _tensor_factory->template create<int64_t>();
pt->reshape(shape);
......@@ -77,7 +79,7 @@ TEST_F(PaddleTensorTest, reshape_test) {
}
TEST_F(PaddleTensorTest, add_test) {
std::vector<size_t> shape = {1};
std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape);
......@@ -89,7 +91,7 @@ TEST_F(PaddleTensorTest, add_test) {
}
TEST_F(PaddleTensorTest, sub_test) {
std::vector<size_t> shape = {1};
std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape);
......@@ -101,7 +103,7 @@ TEST_F(PaddleTensorTest, sub_test) {
}
TEST_F(PaddleTensorTest, negative_test) {
std::vector<size_t> shape = {1};
std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 2;
......@@ -111,7 +113,7 @@ TEST_F(PaddleTensorTest, negative_test) {
}
TEST_F(PaddleTensorTest, mul_test) {
std::vector<size_t> shape = {1};
std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape);
......@@ -123,7 +125,7 @@ TEST_F(PaddleTensorTest, mul_test) {
}
TEST_F(PaddleTensorTest, div_test) {
std::vector<size_t> shape = {1};
std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape);
......@@ -135,9 +137,9 @@ TEST_F(PaddleTensorTest, div_test) {
}
TEST_F(PaddleTensorTest, matmul_test) {
std::vector<size_t> shape0 = {2, 3};
std::vector<size_t> shape1 = {3, 2};
std::vector<size_t> shape2 = {2, 2};
std::vector<size_t> shape0 = { 2, 3 };
std::vector<size_t> shape1 = { 3, 2 };
std::vector<size_t> shape2 = { 2, 2 };
auto pt0 = _tensor_factory->template create<int64_t>(shape0);
auto pt1 = _tensor_factory->template create<int64_t>(shape1);
auto pt2 = _tensor_factory->template create<int64_t>(shape2);
......@@ -151,7 +153,7 @@ TEST_F(PaddleTensorTest, matmul_test) {
// | 3 4 5 | x | 2 3 | = | 28 40 |
// | 4 5 |
std::vector<int64_t> res = {10, 13, 28, 40};
std::vector<int64_t> res = { 10, 13, 28, 40 };
bool eq = std::equal(res.begin(), res.end(), pt2->data());
......@@ -159,7 +161,7 @@ TEST_F(PaddleTensorTest, matmul_test) {
}
TEST_F(PaddleTensorTest, xor_test) {
std::vector<size_t> shape = {1};
std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape);
......@@ -171,7 +173,7 @@ TEST_F(PaddleTensorTest, xor_test) {
}
TEST_F(PaddleTensorTest, and_test) {
std::vector<size_t> shape = {1};
std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape);
......@@ -183,7 +185,7 @@ TEST_F(PaddleTensorTest, and_test) {
}
TEST_F(PaddleTensorTest, or_test) {
std::vector<size_t> shape = {1};
std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape);
auto pt2 = _tensor_factory->template create<int64_t>(shape);
......@@ -195,7 +197,7 @@ TEST_F(PaddleTensorTest, or_test) {
}
TEST_F(PaddleTensorTest, not_test) {
std::vector<size_t> shape = {1};
std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 0;
......@@ -205,7 +207,7 @@ TEST_F(PaddleTensorTest, not_test) {
}
TEST_F(PaddleTensorTest, lshift_test) {
std::vector<size_t> shape = {1};
std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 2;
......@@ -215,7 +217,7 @@ TEST_F(PaddleTensorTest, lshift_test) {
}
TEST_F(PaddleTensorTest, rshift_test) {
std::vector<size_t> shape = {1};
std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = 2;
......@@ -225,7 +227,7 @@ TEST_F(PaddleTensorTest, rshift_test) {
}
TEST_F(PaddleTensorTest, logical_rshift_test) {
std::vector<size_t> shape = {1};
std::vector<size_t> shape = { 1 };
auto pt0 = _tensor_factory->template create<int64_t>(shape);
auto pt1 = _tensor_factory->template create<int64_t>(shape);
pt0->data()[0] = -1;
......@@ -234,16 +236,17 @@ TEST_F(PaddleTensorTest, logical_rshift_test) {
EXPECT_EQ(-1ull >> 1, pt1->data()[0]);
}
TEST_F(PaddleTensorTest, scale_test) {
auto pt = _tensor_factory->template create<int64_t>();
auto pt_ = dynamic_cast<PaddleTensor<int64_t> *>(pt.get());
auto pt_ = dynamic_cast<PaddleTensor<int64_t>*>(pt.get());
pt_->scaling_factor() = 1;
Tensor t;
int dim[1] = {1};
int dim[1] = { 1 };
paddle::framework::DDim ddim(dim, 1);
t.template mutable_data<float>(ddim, _cpu_ctx.GetPlace());
......@@ -258,11 +261,11 @@ TEST_F(PaddleTensorTest, scale_test) {
TEST_F(PaddleTensorTest, scalar_test) {
auto pt = _tensor_factory->template create<int64_t>();
auto pt_ = dynamic_cast<PaddleTensor<int64_t> *>(pt.get());
auto pt_ = dynamic_cast<PaddleTensor<int64_t>*>(pt.get());
pt_->scaling_factor() = 1;
std::vector<size_t> shape = {2};
std::vector<size_t> shape = { 2 };
pt_->template from_float_point_scalar(0.25f, shape, 2);
EXPECT_EQ(2, pt_->scaling_factor());
......@@ -271,11 +274,11 @@ TEST_F(PaddleTensorTest, scalar_test) {
}
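The scaling factor exercised by scale_test and scalar_test appears to be a power-of-two fixed-point exponent; under that assumption, from_float_point_scalar(0.25f, shape, 2) would store round(0.25 * 2^2) = 1 in each element. A sketch of the assumed encoding:

def encode_fixed_point(value, scaling_factor):
    # assumed convention: store round(value * 2**scaling_factor) as an int64
    return int(round(value * (1 << scaling_factor)))

def decode_fixed_point(stored, scaling_factor):
    return stored / (1 << scaling_factor)

assert encode_fixed_point(0.25, 2) == 1
assert decode_fixed_point(1, 2) == 0.25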
TEST_F(PaddleTensorTest, slice_test) {
std::vector<size_t> shape = {2, 2};
std::vector<size_t> shape = { 2, 2 };
auto pt = _tensor_factory->template create<int64_t>(shape);
auto ret = _tensor_factory->template create<int64_t>();
auto pt_ = dynamic_cast<PaddleTensor<int64_t> *>(pt.get());
auto pt_ = dynamic_cast<PaddleTensor<int64_t>*>(pt.get());
pt_->scaling_factor() = 1;
for (size_t i = 0; i < 4; ++i) {
......
......@@ -21,14 +21,13 @@ from paddle.fluid import core
from paddle.fluid import unique_name
from paddle.fluid.framework import Variable
from paddle.fluid.framework import convert_np_dtype_to_dtype_
from paddle.fluid.data_feeder import check_type, check_dtype
class MpcVariable(Variable):
"""
Extends from paddle.fluid.framework.Variable and rewrite
the __init__ method where the shape is resized.
"""
def __init__(self,
block,
type=core.VarDesc.VarType.LOD_TENSOR,
......@@ -91,22 +90,22 @@ class MpcVariable(Variable):
else:
old_dtype = self.dtype
if dtype != old_dtype:
raise ValueError(
"MpcVariable {0} has been created before. "
raise ValueError("MpcVariable {0} has been created before. "
"The previous data type is {1}; the new "
"data type is {2}. They are not "
"matched.".format(self.name, old_dtype, dtype))
"matched.".format(self.name, old_dtype,
dtype))
if lod_level is not None:
if is_new_var:
self.desc.set_lod_level(lod_level)
else:
if lod_level != self.lod_level:
raise ValueError(
"MpcVariable {0} has been created before. "
raise ValueError("MpcVariable {0} has been created before. "
"The previous lod_level is {1}; the new "
"lod_level is {2}. They are not "
"matched".format(self.name, self.lod_level, lod_level))
"matched".format(self.name, self.lod_level,
lod_level))
if persistable is not None:
if is_new_var:
self.desc.set_persistable(persistable)
......@@ -156,8 +155,7 @@ class MpcParameter(MpcVariable):
if len(shape) == 0:
raise ValueError(
"The dimensions of shape for MpcParameter must be greater than 0"
)
"The dimensions of shape for MpcParameter must be greater than 0")
for each in shape:
if each < 0:
......@@ -175,8 +173,7 @@ class MpcParameter(MpcVariable):
**kwargs)
self.trainable = kwargs.get('trainable', True)
self.optimize_attr = kwargs.get('optimize_attr',
{'learning_rate': 1.0})
self.optimize_attr = kwargs.get('optimize_attr', {'learning_rate': 1.0})
self.regularizer = kwargs.get('regularizer', None)
......@@ -203,8 +200,8 @@ class MpcParameter(MpcVariable):
additional_attr = ("trainable", "optimize_attr", "regularizer",
"gradient_clip_attr", "do_model_average")
for attr_name in additional_attr:
res_str += "%s: %s\n" % (
attr_name, cpt.to_text(getattr(self, attr_name)))
res_str += "%s: %s\n" % (attr_name,
cpt.to_text(getattr(self, attr_name)))
else:
res_str = MpcVariable.to_string(self, throw_on_error, False)
return res_str
......@@ -245,8 +242,7 @@ def create_mpc_parameter(block, *args, **kwargs):
init_ops_len = len(init_ops)
if init_ops_len > 1:
raise RuntimeError("mpc_param " + mpc_param.name +
" is inited by multiple init ops " + str(
init_ops))
" is inited by multiple init ops " + str(init_ops))
elif init_ops_len == 1:
# TODO(Paddle 1.7): already inited, do nothing, should log a warning
pass
......@@ -272,7 +268,6 @@ def create_mpc_var(block, *args, **kwargs):
kwargs['initializer'](var, block)
return var
def is_mpc_parameter(var):
"""
Check whether the given variable is an instance of MpcParameter.
......@@ -282,4 +277,13 @@ def is_mpc_parameter(var):
bool: True if the given `var` is an instance of MpcParameter,
False if not.
"""
return isinstance(var, MpcParameter)
return type(var) == MpcParameter
def check_mpc_variable_and_dtype(input,
input_name,
expected_dtype,
op_name,
extra_message=''):
check_type(input, input_name, MpcVariable, op_name, extra_message)
check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message)
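The new helper chains fluid's check_type and check_dtype against MpcVariable. The layer functions later in this commit call it like so (x is a layer input, as in the mul layer below):

# raises a type error if x is not an MpcVariable, and a dtype error if
# x.dtype is not int64
check_mpc_variable_and_dtype(x, 'x', ['int64'], 'mul')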
......@@ -14,9 +14,10 @@
"""
basic mpc op layers.
"""
from paddle.fluid.data_feeder import check_type_and_dtype
from paddle.fluid.data_feeder import check_variable_and_dtype
from ..framework import MpcVariable
from ..framework import check_mpc_variable_and_dtype
from ..mpc_layer_helper import MpcLayerHelper
__all__ = [
......@@ -32,8 +33,8 @@ def _elementwise_op(helper):
assert x is not None, 'x cannot be None in {}'.format(op_type)
assert y is not None, 'y cannot be None in {}'.format(op_type)
check_type_and_dtype(x, 'x', MpcVariable, ['int64'], op_type)
check_type_and_dtype(y, 'y', MpcVariable, ['int64'], op_type)
check_mpc_variable_and_dtype(x, 'x', ['int64'], op_type)
check_mpc_variable_and_dtype(y, 'y', ['int64'], op_type)
axis = helper.kwargs.get('axis', -1)
use_mkldnn = helper.kwargs.get('use_mkldnn', False)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,7 +14,6 @@
"""
mpc math compare layers.
"""
from paddle.fluid.data_feeder import check_type_and_dtype
from ..framework import MpcVariable
from ..mpc_layer_helper import MpcLayerHelper
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,9 +14,9 @@
"""
mpc math op layers.
"""
from paddle.fluid.data_feeder import check_type_and_dtype
from ..framework import MpcVariable
from ..framework import check_mpc_variable_and_dtype
from ..mpc_layer_helper import MpcLayerHelper
__all__ = [
......@@ -39,7 +39,7 @@ def mean(x, name=None):
Examples: todo
"""
helper = MpcLayerHelper("mean", **locals())
check_type_and_dtype(x, 'x', MpcVariable, ['int64'], 'mean')
check_mpc_variable_and_dtype(x, 'x', ['int64'], 'mean')
if name is None:
out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype)
else:
......@@ -64,7 +64,7 @@ def square(x, name=None):
Examples: todo
"""
helper = MpcLayerHelper("square", **locals())
check_type_and_dtype(x, 'x', MpcVariable, ['int64'], 'square')
check_mpc_variable_and_dtype(x, 'x', ['int64'], 'square')
if name is None:
out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype)
else:
......@@ -89,8 +89,7 @@ def sum(x):
Examples: todo
"""
helper = MpcLayerHelper("sum", **locals())
out = helper.create_mpc_variable_for_type_inference(
dtype=helper.input_dtype('x'))
out = helper.create_mpc_variable_for_type_inference(dtype=helper.input_dtype('x'))
helper.append_op(
type="mpc_sum",
inputs={"X": x},
......@@ -116,16 +115,14 @@ def square_error_cost(input, label):
Examples: todo
"""
helper = MpcLayerHelper('square_error_cost', **locals())
minus_out = helper.create_mpc_variable_for_type_inference(
dtype=input.dtype)
minus_out = helper.create_mpc_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='mpc_elementwise_sub',
inputs={'X': [input],
'Y': [label]},
outputs={'Out': [minus_out]})
square_out = helper.create_mpc_variable_for_type_inference(
dtype=input.dtype)
square_out = helper.create_mpc_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='mpc_square',
inputs={'X': [minus_out]},
......
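Numerically, square_error_cost composes the two appended ops into (input - label)^2 per element. A plaintext Python sketch, for illustration only:

import numpy as np

def square_error_cost(input, label):
    minus_out = input - label       # mpc_elementwise_sub
    return minus_out * minus_out    # mpc_square

print(square_error_cost(np.array([1.0, 2.0]), np.array([0.0, 4.0])))  # [1. 4.]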
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,12 +14,14 @@
"""
mpc matrix op layers.
"""
from paddle.fluid.data_feeder import check_type_and_dtype
from ..framework import MpcVariable
from ..framework import check_mpc_variable_and_dtype
from ..mpc_layer_helper import MpcLayerHelper
__all__ = ['mul', ]
__all__ = [
'mul',
]
def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
......@@ -66,8 +68,8 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
}
helper = MpcLayerHelper("mul", **locals())
check_type_and_dtype(x, 'x', MpcVariable, ['int64'], 'mul')
check_type_and_dtype(y, 'y', MpcVariable, ['int64'], 'mul')
check_mpc_variable_and_dtype(x, 'x', ['int64'], 'mul')
check_mpc_variable_and_dtype(y, 'y', ['int64'], 'mul')
if name is None:
out = helper.create_mpc_variable_for_type_inference(dtype=x.dtype)
else:
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -17,9 +17,9 @@ mpc ml op layers.
from functools import reduce
from paddle.fluid.data_feeder import check_type, check_dtype
from paddle.fluid.data_feeder import check_type_and_dtype
import numpy
from ..framework import MpcVariable
from ..framework import check_mpc_variable_and_dtype
from ..mpc_layer_helper import MpcLayerHelper
__all__ = [
......@@ -30,9 +30,6 @@ __all__ = [
]
# add softmax, relu
def fc(input,
size,
num_flatten_dims=1,
......@@ -186,8 +183,7 @@ def softmax(input, use_cudnn=False, name=None, axis=-1):
"""
attrs = {"axis": axis, "use_cudnn": use_cudnn}
helper = MpcLayerHelper('softmax', **locals())
check_type_and_dtype(input, 'input', MpcVariable,
['float16', 'float32', 'float64'], 'softmax')
check_mpc_variable_and_dtype(input, 'input', ['int64'], 'softmax')
dtype = helper.input_dtype()
mpc_softmax_out = helper.create_mpc_variable_for_type_inference(dtype)
......@@ -226,7 +222,9 @@ def relu(input, name=None):
dtype = helper.input_dtype(input_param_name='input')
out = helper.create_mpc_variable_for_type_inference(dtype)
helper.append_op(
type="mpc_relu", inputs={"X": input}, outputs={"Y": out})
type="mpc_relu",
inputs={"X": input},
outputs={"Y": out})
return out
......
......@@ -32,7 +32,6 @@ def monkey_patch_mpc_variable():
Monkey patch for operator overloading.
:return:
"""
def unique_tmp_name():
"""
Generate temp name for variable.
......@@ -80,7 +79,9 @@ def monkey_patch_mpc_variable():
tmp_name = unique_tmp_name()
return block.create_var(name=tmp_name, dtype=dtype)
def _elemwise_method_creator_(method_name, op_type, reverse=False):
def _elemwise_method_creator_(method_name,
op_type,
reverse=False):
"""
Operator overloading for different method.
:param method_name: the name of operator which is overloaded.
......@@ -88,19 +89,16 @@ def monkey_patch_mpc_variable():
:param reverse:
:return:
"""
def __impl__(self, other_var):
lhs_dtype = safe_get_dtype(self)
if method_name in compare_ops:
if not isinstance(other_var, Variable):
raise NotImplementedError(
"Unsupported data type of {} for compare operations."
raise NotImplementedError("Unsupported data type of {} for compare operations."
.format(other_var.name))
else:
if not isinstance(other_var, MpcVariable):
raise NotImplementedError(
"Unsupported data type of {}.".format(other_var.name))
raise NotImplementedError("Unsupported data type of {}.".format(other_var.name))
rhs_dtype = safe_get_dtype(other_var)
if reverse:
......@@ -111,8 +109,7 @@ def monkey_patch_mpc_variable():
if method_name in compare_ops:
out = create_new_tmp_var(current_block(self), dtype=rhs_dtype)
else:
out = create_new_tmp_mpc_var(
current_block(self), dtype=lhs_dtype)
out = create_new_tmp_mpc_var(current_block(self), dtype=lhs_dtype)
# out = create_new_tmp_mpc_var(current_block(self), dtype=lhs_dtype)
......@@ -179,10 +176,11 @@ def monkey_patch_mpc_variable():
("__lt__", "mpc_less_than", False),
("__le__", "mpc_less_equal", False),
("__gt__", "mpc_greater_than", False),
("__ge__", "mpc_greater_equal", False)):
("__ge__", "mpc_greater_equal", False)
):
# Not support computation between MpcVariable and scalar.
setattr(MpcVariable, method_name,
setattr(MpcVariable,
method_name,
_elemwise_method_creator_(method_name, op_type, reverse)
if method_name in supported_mpc_ops else announce_not_impl)
# MpcVariable.astype = astype
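Once the patch runs, comparison operators on an MpcVariable dispatch to the mpc_* compare ops listed above, for example (x and y are hypothetical MpcVariable instances in the current block):

z = x < y     # __lt__ appends an mpc_less_than op; z is a temporary Variable
w = x >= y    # __ge__ maps to mpc_greater_equal the same way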
......@@ -34,7 +34,7 @@ def python_version():
max_version, mid_version, min_version = python_version()
REQUIRED_PACKAGES = [
'six >= 1.10.0', 'protobuf >= 3.1.0', 'paddlepaddle == 1.6.3', 'paddlepaddle-gpu >= 1.8'
'six >= 1.10.0', 'protobuf >= 3.1.0', 'paddlepaddle == 1.8.0', 'paddlepaddle-gpu >= 1.8'
]
if max_version < 3:
......