Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleFL
提交
a1a9bf6b
P
PaddleFL
项目概览
PaddlePaddle
/
PaddleFL
通知
35
Star
5
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
6
列表
看板
标记
里程碑
合并请求
4
Wiki
3
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleFL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
6
Issue
6
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
3
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a1a9bf6b
编写于
5月 17, 2020
作者:
Q
Qinghe JING
提交者:
GitHub
5月 17, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #70 from jhjiangcs/smc-611
improve code to support PaddlePaddle1.8.0.
上级
653068b8
48813130
变更
40
展开全部
显示空白变更内容
内联
并排
Showing
40 changed file
with
5234 addition
and
5103 deletion
+5234
-5103
CMakeLists.txt
CMakeLists.txt
+2
-2
cmake/external/gtest.cmake
cmake/external/gtest.cmake
+2
-2
cmake/external/hiredis.cmake
cmake/external/hiredis.cmake
+2
-2
cmake/external/openssl.cmake
cmake/external/openssl.cmake
+2
-2
cmake/external/pybind11.cmake
cmake/external/pybind11.cmake
+2
-2
cmake/generic.cmake
cmake/generic.cmake
+2
-2
cmake/third_party.cmake
cmake/third_party.cmake
+2
-2
core/paddlefl_mpc/operators/mpc_compare_op.cc
core/paddlefl_mpc/operators/mpc_compare_op.cc
+72
-84
core/paddlefl_mpc/operators/mpc_compare_op.h
core/paddlefl_mpc/operators/mpc_compare_op.h
+40
-49
core/paddlefl_mpc/operators/mpc_elementwise_add_op.cc
core/paddlefl_mpc/operators/mpc_elementwise_add_op.cc
+82
-88
core/paddlefl_mpc/operators/mpc_elementwise_add_op.h
core/paddlefl_mpc/operators/mpc_elementwise_add_op.h
+158
-161
core/paddlefl_mpc/operators/mpc_elementwise_sub_op.cc
core/paddlefl_mpc/operators/mpc_elementwise_sub_op.cc
+79
-87
core/paddlefl_mpc/operators/mpc_elementwise_sub_op.h
core/paddlefl_mpc/operators/mpc_elementwise_sub_op.h
+39
-41
core/paddlefl_mpc/operators/mpc_init_op.cc
core/paddlefl_mpc/operators/mpc_init_op.cc
+56
-52
core/paddlefl_mpc/operators/mpc_mean_op.cc
core/paddlefl_mpc/operators/mpc_mean_op.cc
+59
-57
core/paddlefl_mpc/operators/mpc_mean_op.h
core/paddlefl_mpc/operators/mpc_mean_op.h
+41
-45
core/paddlefl_mpc/operators/mpc_mul_op.cc
core/paddlefl_mpc/operators/mpc_mul_op.cc
+177
-174
core/paddlefl_mpc/operators/mpc_mul_op.h
core/paddlefl_mpc/operators/mpc_mul_op.h
+171
-187
core/paddlefl_mpc/operators/mpc_op.h
core/paddlefl_mpc/operators/mpc_op.h
+30
-28
core/paddlefl_mpc/operators/mpc_relu_op.cc
core/paddlefl_mpc/operators/mpc_relu_op.cc
+28
-25
core/paddlefl_mpc/operators/mpc_relu_op.h
core/paddlefl_mpc/operators/mpc_relu_op.h
+18
-24
core/paddlefl_mpc/operators/mpc_sgd_op.cc
core/paddlefl_mpc/operators/mpc_sgd_op.cc
+75
-75
core/paddlefl_mpc/operators/mpc_sgd_op.h
core/paddlefl_mpc/operators/mpc_sgd_op.h
+51
-57
core/paddlefl_mpc/operators/mpc_sigmoid_cross_entropy_with_logits_op.cc
...mpc/operators/mpc_sigmoid_cross_entropy_with_logits_op.cc
+10
-12
core/paddlefl_mpc/operators/mpc_square_op.cc
core/paddlefl_mpc/operators/mpc_square_op.cc
+56
-55
core/paddlefl_mpc/operators/mpc_square_op.h
core/paddlefl_mpc/operators/mpc_square_op.h
+34
-36
core/paddlefl_mpc/operators/mpc_sum_op.cc
core/paddlefl_mpc/operators/mpc_sum_op.cc
+118
-115
core/paddlefl_mpc/operators/mpc_sum_op.h
core/paddlefl_mpc/operators/mpc_sum_op.h
+62
-67
core/privc3/boolean_tensor_test.cc
core/privc3/boolean_tensor_test.cc
+1213
-1073
core/privc3/fixedpoint_tensor_test.cc
core/privc3/fixedpoint_tensor_test.cc
+2279
-2227
core/privc3/paddle_tensor_test.cc
core/privc3/paddle_tensor_test.cc
+179
-176
python/paddle_fl/mpc/framework.py
python/paddle_fl/mpc/framework.py
+26
-22
python/paddle_fl/mpc/layers/__init__.py
python/paddle_fl/mpc/layers/__init__.py
+1
-1
python/paddle_fl/mpc/layers/basic.py
python/paddle_fl/mpc/layers/basic.py
+4
-3
python/paddle_fl/mpc/layers/compare.py
python/paddle_fl/mpc/layers/compare.py
+1
-2
python/paddle_fl/mpc/layers/math.py
python/paddle_fl/mpc/layers/math.py
+8
-11
python/paddle_fl/mpc/layers/matrix.py
python/paddle_fl/mpc/layers/matrix.py
+12
-10
python/paddle_fl/mpc/layers/ml.py
python/paddle_fl/mpc/layers/ml.py
+6
-8
python/paddle_fl/mpc/layers/mpc_math_op_patch.py
python/paddle_fl/mpc/layers/mpc_math_op_patch.py
+34
-36
python/setup.py
python/setup.py
+1
-1
未找到文件。
CMakeLists.txt
浏览文件 @
a1a9bf6b
...
@@ -34,8 +34,8 @@ execute_process(COMMAND ${PYTHON} -c "import paddle;print(paddle.version.full_ve
...
@@ -34,8 +34,8 @@ execute_process(COMMAND ${PYTHON} -c "import paddle;print(paddle.version.full_ve
RESULT_VARIABLE ret OUTPUT_VARIABLE paddle_version OUTPUT_STRIP_TRAILING_WHITESPACE
)
RESULT_VARIABLE ret OUTPUT_VARIABLE paddle_version OUTPUT_STRIP_TRAILING_WHITESPACE
)
if
(
NOT ret
)
if
(
NOT ret
)
if
(
NOT
${
paddle_version
}
STREQUAL
"1.
6.3
"
)
if
(
NOT
${
paddle_version
}
STREQUAL
"1.
8.0
"
)
message
(
FATAL_ERROR
"Paddle installation of 1.
6.3
is required but
${
paddle_version
}
is found"
)
message
(
FATAL_ERROR
"Paddle installation of 1.
8.0
is required but
${
paddle_version
}
is found"
)
endif
()
endif
()
else
()
else
()
message
(
FATAL_ERROR
"Could not get paddle version."
)
message
(
FATAL_ERROR
"Could not get paddle version."
)
...
...
cmake/external/gtest.cmake
浏览文件 @
a1a9bf6b
#
Copyright (c) 2019
PaddlePaddle Authors. All Rights Reserved.
#
Copyright (c) 2020
PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
cmake/external/hiredis.cmake
浏览文件 @
a1a9bf6b
cmake/external/openssl.cmake
浏览文件 @
a1a9bf6b
cmake/external/pybind11.cmake
浏览文件 @
a1a9bf6b
cmake/generic.cmake
浏览文件 @
a1a9bf6b
cmake/third_party.cmake
浏览文件 @
a1a9bf6b
core/paddlefl_mpc/operators/mpc_compare_op.cc
浏览文件 @
a1a9bf6b
/
/
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/
*
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.
limitations under the License. */
#include "mpc_compare_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_compare_op.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -25,16 +25,16 @@ class MpcCompareOp : public framework::OperatorWithKernel {
...
@@ -25,16 +25,16 @@ class MpcCompareOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
PADDLE_ENFORCE_EQ
(
platform
::
errors
::
NotFound
(
ctx
->
HasInput
(
"X"
),
true
,
"Input(X) of MpcCompareOp should not be null."
));
platform
::
errors
::
NotFound
(
"Input(X) of MpcCompareOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Y"
),
true
,
PADDLE_ENFORCE_EQ
(
platform
::
errors
::
NotFound
(
ctx
->
HasInput
(
"Y"
),
true
,
"Input(Y) of MpcCompareOp should not be null."
));
platform
::
errors
::
NotFound
(
"Input(Y) of MpcCompareOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
PADDLE_ENFORCE_EQ
(
platform
::
errors
::
NotFound
(
ctx
->
HasOutput
(
"Out"
),
true
,
"Output(Out) of MpcCompareOp should not be null."
));
platform
::
errors
::
NotFound
(
"Output(Out) of MpcCompareOp should not be null."
));
auto
dim_x
=
ctx
->
GetInputDim
(
"X"
);
auto
dim_x
=
ctx
->
GetInputDim
(
"X"
);
auto
dim_y
=
ctx
->
GetInputDim
(
"Y"
);
auto
dim_y
=
ctx
->
GetInputDim
(
"Y"
);
...
@@ -45,11 +45,12 @@ public:
...
@@ -45,11 +45,12 @@ public:
ctx
->
ShareLoD
(
"Y"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"Y"
,
/*->*/
"Out"
);
}
}
framework
::
OpKernelType
framework
::
OpKernelType
GetExpectedKernelType
(
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
}
}
};
};
class
MpcCompareOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
MpcCompareOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
@@ -68,40 +69,27 @@ MPC Compare Operator.
...
@@ -68,40 +69,27 @@ MPC Compare Operator.
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
mpc_greater_than
,
ops
::
MpcCompareOp
,
REGISTER_OP_WITHOUT_GRADIENT
(
mpc_greater_than
,
ops
::
MpcCompareOp
,
ops
::
MpcCompareOpMaker
);
ops
::
MpcCompareOpMaker
);
REGISTER_OP_CPU_KERNEL
(
mpc_greater_than
,
REGISTER_OP_CPU_KERNEL
(
ops
::
MpcCompareOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
,
ops
::
MpcGreaterThanFunctor
>
);
mpc_greater_than
,
ops
::
MpcCompareOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
,
REGISTER_OP_WITHOUT_GRADIENT
(
mpc_greater_equal
,
ops
::
MpcCompareOp
,
ops
::
MpcCompareOpMaker
);
ops
::
MpcGreaterThanFunctor
>
);
REGISTER_OP_CPU_KERNEL
(
mpc_greater_equal
,
ops
::
MpcCompareOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
,
ops
::
MpcGreaterEqualFunctor
>
);
REGISTER_OP_WITHOUT_GRADIENT
(
mpc_greater_equal
,
ops
::
MpcCompareOp
,
ops
::
MpcCompareOpMaker
);
REGISTER_OP_WITHOUT_GRADIENT
(
mpc_less_than
,
ops
::
MpcCompareOp
,
ops
::
MpcCompareOpMaker
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
mpc_less_than
,
mpc_greater_equal
,
ops
::
MpcCompareOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
,
ops
::
MpcLessThanFunctor
>
);
ops
::
MpcCompareOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
,
ops
::
MpcGreaterEqualFunctor
>
);
REGISTER_OP_WITHOUT_GRADIENT
(
mpc_less_equal
,
ops
::
MpcCompareOp
,
ops
::
MpcCompareOpMaker
);
REGISTER_OP_CPU_KERNEL
(
mpc_less_equal
,
REGISTER_OP_WITHOUT_GRADIENT
(
mpc_less_than
,
ops
::
MpcCompareOp
,
ops
::
MpcCompareOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
,
ops
::
MpcLessEqualFunctor
>
);
ops
::
MpcCompareOpMaker
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_WITHOUT_GRADIENT
(
mpc_equal
,
ops
::
MpcCompareOp
,
ops
::
MpcCompareOpMaker
);
mpc_less_than
,
ops
::
MpcCompareOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
REGISTER_OP_CPU_KERNEL
(
mpc_equal
,
int64_t
,
ops
::
MpcLessThanFunctor
>
);
ops
::
MpcCompareOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
,
ops
::
MpcEqualFunctor
>
);
REGISTER_OP_WITHOUT_GRADIENT
(
mpc_less_equal
,
ops
::
MpcCompareOp
,
REGISTER_OP_WITHOUT_GRADIENT
(
mpc_not_equal
,
ops
::
MpcCompareOp
,
ops
::
MpcCompareOpMaker
);
ops
::
MpcCompareOpMaker
);
REGISTER_OP_CPU_KERNEL
(
mpc_not_equal
,
REGISTER_OP_CPU_KERNEL
(
ops
::
MpcCompareOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
,
ops
::
MpcNotEqualFunctor
>
);
mpc_less_equal
,
ops
::
MpcCompareOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
,
ops
::
MpcLessEqualFunctor
>
);
REGISTER_OP_WITHOUT_GRADIENT
(
mpc_equal
,
ops
::
MpcCompareOp
,
ops
::
MpcCompareOpMaker
);
REGISTER_OP_CPU_KERNEL
(
mpc_equal
,
ops
::
MpcCompareOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
,
ops
::
MpcEqualFunctor
>
);
REGISTER_OP_WITHOUT_GRADIENT
(
mpc_not_equal
,
ops
::
MpcCompareOp
,
ops
::
MpcCompareOpMaker
);
REGISTER_OP_CPU_KERNEL
(
mpc_not_equal
,
ops
::
MpcCompareOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
,
ops
::
MpcNotEqualFunctor
>
);
core/paddlefl_mpc/operators/mpc_compare_op.h
浏览文件 @
a1a9bf6b
/
/
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/
*
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.uage governing permissions and
limitations under the License. */
#pragma once
#pragma once
#include "mpc_op.h"
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
#include <math.h>
#include <type_traits>
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -25,50 +22,44 @@ using Tensor = framework::Tensor;
...
@@ -25,50 +22,44 @@ using Tensor = framework::Tensor;
struct
MpcGreaterThanFunctor
{
struct
MpcGreaterThanFunctor
{
void
Run
(
const
Tensor
*
in_x_t
,
const
Tensor
*
in_y_t
,
Tensor
*
out_t
)
{
void
Run
(
const
Tensor
*
in_x_t
,
const
Tensor
*
in_y_t
,
Tensor
*
out_t
)
{
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
gt
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
gt
(
in_x_t
,
in_y_t
,
out_t
);
in_x_t
,
in_y_t
,
out_t
);
}
}
};
};
struct
MpcGreaterEqualFunctor
{
struct
MpcGreaterEqualFunctor
{
void
Run
(
const
Tensor
*
in_x_t
,
const
Tensor
*
in_y_t
,
Tensor
*
out_t
)
{
void
Run
(
const
Tensor
*
in_x_t
,
const
Tensor
*
in_y_t
,
Tensor
*
out_t
)
{
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
geq
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
geq
(
in_x_t
,
in_y_t
,
out_t
);
in_x_t
,
in_y_t
,
out_t
);
}
}
};
};
struct
MpcLessThanFunctor
{
struct
MpcLessThanFunctor
{
void
Run
(
const
Tensor
*
in_x_t
,
const
Tensor
*
in_y_t
,
Tensor
*
out_t
)
{
void
Run
(
const
Tensor
*
in_x_t
,
const
Tensor
*
in_y_t
,
Tensor
*
out_t
)
{
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
lt
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
lt
(
in_x_t
,
in_y_t
,
out_t
);
in_x_t
,
in_y_t
,
out_t
);
}
}
};
};
struct
MpcLessEqualFunctor
{
struct
MpcLessEqualFunctor
{
void
Run
(
const
Tensor
*
in_x_t
,
const
Tensor
*
in_y_t
,
Tensor
*
out_t
)
{
void
Run
(
const
Tensor
*
in_x_t
,
const
Tensor
*
in_y_t
,
Tensor
*
out_t
)
{
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
leq
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
leq
(
in_x_t
,
in_y_t
,
out_t
);
in_x_t
,
in_y_t
,
out_t
);
}
}
};
};
struct
MpcEqualFunctor
{
struct
MpcEqualFunctor
{
void
Run
(
const
Tensor
*
in_x_t
,
const
Tensor
*
in_y_t
,
Tensor
*
out_t
)
{
void
Run
(
const
Tensor
*
in_x_t
,
const
Tensor
*
in_y_t
,
Tensor
*
out_t
)
{
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
eq
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
eq
(
in_x_t
,
in_y_t
,
out_t
);
in_x_t
,
in_y_t
,
out_t
);
}
}
};
};
struct
MpcNotEqualFunctor
{
struct
MpcNotEqualFunctor
{
void
Run
(
const
Tensor
*
in_x_t
,
const
Tensor
*
in_y_t
,
Tensor
*
out_t
)
{
void
Run
(
const
Tensor
*
in_x_t
,
const
Tensor
*
in_y_t
,
Tensor
*
out_t
)
{
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
neq
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
neq
(
in_x_t
,
in_y_t
,
out_t
);
in_x_t
,
in_y_t
,
out_t
);
}
}
};
};
template
<
typename
DeviceContext
,
typename
T
,
typename
Functor
>
template
<
typename
DeviceContext
,
typename
T
,
typename
Functor
>
class
MpcCompareOpKernel
:
public
MpcOpKernel
<
T
>
{
class
MpcCompareOpKernel
:
public
MpcOpKernel
<
T
>
{
public:
public:
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
in_x_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
in_x_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
in_y_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Y"
);
auto
*
in_y_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Y"
);
auto
*
out_t
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
*
out_t
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
...
...
core/paddlefl_mpc/operators/mpc_elementwise_add_op.cc
浏览文件 @
a1a9bf6b
/
/
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/
*
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.
limitations under the License. */
#include "mpc_elementwise_add_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_elementwise_add_op.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -24,38 +24,33 @@ class MpcElementwiseAddOp : public framework::OperatorWithKernel {
...
@@ -24,38 +24,33 @@ class MpcElementwiseAddOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
ctx
->
HasInput
(
"X"
),
true
,
platform
::
errors
::
NotFound
(
platform
::
errors
::
NotFound
(
"Input(X) of MpcElementwiseAddOp should not be null."
));
"Input(X) of MpcElementwiseAddOp should not be null."
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Y"
),
true
,
ctx
->
HasInput
(
"Y"
),
true
,
platform
::
errors
::
NotFound
(
platform
::
errors
::
NotFound
(
"Input(Y) of MpcElementwiseAddOp should not be null."
));
"Input(Y) of MpcElementwiseAddOp should not be null."
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
ctx
->
HasOutput
(
"Out"
),
true
,
platform
::
errors
::
NotFound
(
platform
::
errors
::
NotFound
(
"Output(Out) of MpcElementwiseAddOp should not be null."
));
"Output(Out) of MpcElementwiseAddOp should not be null."
));
PADDLE_ENFORCE_GE
(
PADDLE_ENFORCE_GE
(
ctx
->
GetInputDim
(
"X"
).
size
(),
ctx
->
GetInputDim
(
"Y"
).
size
(),
ctx
->
GetInputDim
(
"X"
).
size
(),
ctx
->
GetInputDim
(
"Y"
).
size
(),
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The dimensions of X should be greater than the dimensions of Y. "
"The dimensions of X should be greater than the dimensions of Y. "
"But received the dimensions of X is [%s], the dimensions of Y is "
"But received the dimensions of X is [%s], the dimensions of Y is [%s]"
,
"[%s]"
,
ctx
->
GetInputDim
(
"X"
),
ctx
->
GetInputDim
(
"Y"
)));
ctx
->
GetInputDim
(
"X"
),
ctx
->
GetInputDim
(
"Y"
)));
ctx
->
ShareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
}
};
};
class
MpcElementwiseAddOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
MpcElementwiseAddOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
void
Make
()
override
{
void
Make
()
override
{
AddInput
(
"X"
,
AddInput
(
"X"
,
"(Tensor), The first input tensor of mpc elementwise add op."
);
"(Tensor), The first input tensor of mpc elementwise add op."
);
AddInput
(
"Y"
,
"(Tensor), The second input tensor of mpc elementwise add op."
);
AddInput
(
"Y"
,
"(Tensor), The second input tensor of mpc elementwise add op."
);
AddOutput
(
"Out"
,
"(Tensor), The output tensor of mpc elementwise add op."
);
AddOutput
(
"Out"
,
"(Tensor), The output tensor of mpc elementwise add op."
);
AddAttr
<
int
>
(
"axis"
,
AddAttr
<
int
>
(
"axis"
,
"(int, default -1). If X.dimension != Y.dimension,"
"(int, default -1). If X.dimension != Y.dimension,"
...
@@ -92,24 +87,23 @@ public:
...
@@ -92,24 +87,23 @@ public:
ctx
->
ShareLoD
(
"Y"
,
/*->*/
y_grad_name
);
ctx
->
ShareLoD
(
"Y"
,
/*->*/
y_grad_name
);
}
}
}
}
};
};
template
<
typename
T
>
template
<
typename
T
>
class
MpcElementwiseAddOpGradMaker
:
public
framework
::
SingleGradOp
DescMaker
{
class
MpcElementwiseAddOpGradMaker
:
public
framework
::
SingleGradOp
Maker
<
T
>
{
public:
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDesc
Maker
;
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOp
Maker
;
protected:
protected:
std
::
unique_ptr
<
T
>
Apply
()
const
override
{
void
Apply
(
GradOpPtr
<
T
>
grad
)
const
override
{
std
::
unique_ptr
<
T
>
retv
(
new
T
());
grad
->
SetType
(
"mpc_elementwise_add_grad"
);
retv
->
SetType
(
"mpc_elementwise_add_grad"
);
grad
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
retv
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
grad
->
SetInput
(
"Y"
,
this
->
Input
(
"Y"
));
retv
->
SetInput
(
"Y"
,
this
->
Input
(
"Y"
));
grad
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
retv
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
grad
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
retv
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
grad
->
SetOutput
(
framework
::
GradVarName
(
"Y"
),
this
->
InputGrad
(
"Y"
));
retv
->
SetOutput
(
framework
::
GradVarName
(
"Y"
),
this
->
InputGrad
(
"Y"
));
grad
->
SetAttrMap
(
this
->
Attrs
());
retv
->
SetAttrMap
(
this
->
Attrs
());
return
retv
;
}
}
};
};
...
@@ -127,6 +121,6 @@ REGISTER_OP_CPU_KERNEL(
...
@@ -127,6 +121,6 @@ REGISTER_OP_CPU_KERNEL(
mpc_elementwise_add
,
mpc_elementwise_add
,
ops
::
MpcElementwiseAddKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
ops
::
MpcElementwiseAddKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
REGISTER_OP_CPU_KERNEL
(
mpc_elementwise_add_grad
,
REGISTER_OP_CPU_KERNEL
(
ops
::
MpcElementwiseAddGradKernel
<
mpc_elementwise_add_grad
,
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
ops
::
MpcElementwiseAddGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
core/paddlefl_mpc/operators/mpc_elementwise_add_op.h
浏览文件 @
a1a9bf6b
/
/
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/
*
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.
limitations under the License. */
// This op is different with elementwise_add of PaddlePaddle.
// This op is different with elementwise_add of PaddlePaddle.
// We only consider that the dimensions of X is equal with the dimensions of Y.
// We only consider that the dimensions of X is equal with the dimensions of Y.
...
@@ -18,7 +18,6 @@
...
@@ -18,7 +18,6 @@
#pragma once
#pragma once
#include "mpc_op.h"
#include "mpc_op.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/fluid/platform/transform.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -26,12 +25,12 @@ namespace operators {
...
@@ -26,12 +25,12 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
// paddle/fluid/operators/elementwise/elementwise_op_function.h
// paddle/fluid/operators/elementwise/elementwise_op_function.h
template
<
typename
T
,
typename
DeviceContext
>
class
RowwiseTransformIterator
;
template
<
typename
T
,
typename
DeviceContext
>
class
RowwiseTransformIterator
;
template
<
typename
T
>
template
<
typename
T
>
class
RowwiseTransformIterator
<
T
,
platform
::
CPUDeviceContext
>
class
RowwiseTransformIterator
<
T
,
platform
::
CPUDeviceContext
>
:
public
std
::
iterator
<
std
::
random_access_iterator_tag
,
T
,
std
::
ptrdiff_t
,
:
public
std
::
iterator
<
std
::
random_access_iterator_tag
,
T
,
std
::
ptrdiff_t
,
T
*
,
T
&>
{
T
*
,
T
&>
{
public:
public:
RowwiseTransformIterator
(
const
T
*
ptr
,
int
n
)
:
ptr_
(
ptr
),
i_
(
0
),
n_
(
n
)
{}
RowwiseTransformIterator
(
const
T
*
ptr
,
int
n
)
:
ptr_
(
ptr
),
i_
(
0
),
n_
(
n
)
{}
...
@@ -54,13 +53,11 @@ public:
...
@@ -54,13 +53,11 @@ public:
return
*
this
;
return
*
this
;
}
}
bool
operator
==
(
const
RowwiseTransformIterator
<
T
,
platform
::
CPUDeviceContext
>
bool
operator
==
(
const
RowwiseTransformIterator
<
T
,
platform
::
CPUDeviceContext
>
&
rhs
)
const
{
&
rhs
)
const
{
return
(
ptr_
+
i_
)
==
&
(
*
rhs
);
return
(
ptr_
+
i_
)
==
&
(
*
rhs
);
}
}
bool
operator
!=
(
const
RowwiseTransformIterator
<
T
,
platform
::
CPUDeviceContext
>
bool
operator
!=
(
const
RowwiseTransformIterator
<
T
,
platform
::
CPUDeviceContext
>
&
rhs
)
const
{
&
rhs
)
const
{
return
(
ptr_
+
i_
)
!=
&
(
*
rhs
);
return
(
ptr_
+
i_
)
!=
&
(
*
rhs
);
}
}
...
@@ -72,15 +69,15 @@ private:
...
@@ -72,15 +69,15 @@ private:
int64_t
n_
;
int64_t
n_
;
};
};
template
<
typename
T
>
struct
AddFunctor
{
template
<
typename
T
>
struct
AddFunctor
{
inline
HOSTDEVICE
T
operator
()(
T
x
,
T
y
)
{
return
x
+
y
;
}
inline
HOSTDEVICE
T
operator
()(
T
x
,
T
y
)
{
return
x
+
y
;
}
};
};
struct
GetMidDims
{
struct
GetMidDims
{
inline
HOSTDEVICE
void
operator
()(
const
framework
::
DDim
&
x_dims
,
inline
HOSTDEVICE
void
operator
()(
const
framework
::
DDim
&
x_dims
,
const
framework
::
DDim
&
y_dims
,
const
framework
::
DDim
&
y_dims
,
const
int
axis
,
const
int
axis
,
int
*
pre
,
int
*
n
,
int
*
pre
,
int
*
n
,
int
*
post
)
{
int
*
post
)
{
*
pre
=
1
;
*
pre
=
1
;
*
n
=
1
;
*
n
=
1
;
*
post
=
1
;
*
post
=
1
;
...
@@ -105,18 +102,17 @@ const size_t SHARE_NUM = 2;
...
@@ -105,18 +102,17 @@ const size_t SHARE_NUM = 2;
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
MpcElementwiseAddKernel
:
public
MpcOpKernel
<
T
>
{
class
MpcElementwiseAddKernel
:
public
MpcOpKernel
<
T
>
{
public:
public:
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
in_x_t
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
in_x_t
=
ctx
.
Input
<
framework
::
LoD
Tensor
>
(
"X"
);
auto
*
in_y_t
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
auto
*
in_y_t
=
ctx
.
Input
<
framework
::
LoD
Tensor
>
(
"Y"
);
auto
*
out_t
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
out_t
=
ctx
.
Output
<
framework
::
LoD
Tensor
>
(
"Out"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
out
=
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
out
=
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
in_x_t
->
dims
()
==
in_y_t
->
dims
())
{
if
(
in_x_t
->
dims
()
==
in_y_t
->
dims
())
{
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
add
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
add
(
in_x_t
,
in_y_t
,
out_t
);
in_x_t
,
in_y_t
,
out_t
);
}
else
{
}
else
{
Tensor
in_x_t_slice
;
Tensor
in_x_t_slice
;
Tensor
in_y_t_slice
;
Tensor
in_y_t_slice
;
...
@@ -137,8 +133,8 @@ public:
...
@@ -137,8 +133,8 @@ public:
int
pre
,
n
,
post
;
int
pre
,
n
,
post
;
GetMidDims
get_mid_dims
;
GetMidDims
get_mid_dims
;
get_mid_dims
(
x_dims
,
y_dims
,
axis
,
&
pre
,
&
n
,
&
post
);
get_mid_dims
(
x_dims
,
y_dims
,
axis
,
&
pre
,
&
n
,
&
post
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
post
,
1
,
post
,
1
,
"post should be equal 1, but received post is [%s]"
,
post
);
"post should be equal 1, but received post is [%s]"
,
post
);
auto
x_
=
in_x_t_slice
.
data
<
T
>
();
auto
x_
=
in_x_t_slice
.
data
<
T
>
();
auto
y_
=
in_y_t_slice
.
data
<
T
>
();
auto
y_
=
in_y_t_slice
.
data
<
T
>
();
...
@@ -146,8 +142,8 @@ public:
...
@@ -146,8 +142,8 @@ public:
auto
nx_
=
in_x_t_slice
.
numel
();
auto
nx_
=
in_x_t_slice
.
numel
();
paddle
::
platform
::
Transform
<
DeviceContext
>
trans
;
paddle
::
platform
::
Transform
<
DeviceContext
>
trans
;
trans
(
ctx
.
template
device_context
<
DeviceContext
>(),
x_
,
x_
+
nx_
,
trans
(
ctx
.
template
device_context
<
DeviceContext
>(),
x_
,
x_
+
nx_
,
RowwiseTransformIterator
<
T
,
DeviceContext
>
(
y_
,
n
),
out_
,
RowwiseTransformIterator
<
T
,
DeviceContext
>
(
y_
,
n
)
,
AddFunctor
<
T
>
());
out_
,
AddFunctor
<
T
>
());
}
}
}
}
}
}
...
@@ -159,9 +155,9 @@ public:
...
@@ -159,9 +155,9 @@ public:
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
in_x_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
in_x_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
in_y_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Y"
);
auto
*
in_y_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Y"
);
auto
*
dout
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dout
=
ctx
.
Input
<
framework
::
LoD
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dx
=
ctx
.
Output
<
framework
::
LoD
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
dy
=
ctx
.
Output
<
framework
::
LoD
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
dout_data
=
dout
->
data
<
T
>
();
auto
dout_data
=
dout
->
data
<
T
>
();
...
@@ -189,8 +185,8 @@ public:
...
@@ -189,8 +185,8 @@ public:
int
pre
,
n
,
post
;
int
pre
,
n
,
post
;
GetMidDims
get_mid_dims
;
GetMidDims
get_mid_dims
;
get_mid_dims
(
x_dims
,
y_dims
,
axis
,
&
pre
,
&
n
,
&
post
);
get_mid_dims
(
x_dims
,
y_dims
,
axis
,
&
pre
,
&
n
,
&
post
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
post
,
1
,
post
,
1
,
"post should be equal 1, but received post is [%s]"
,
post
);
"post should be equal 1, but received post is [%s]"
,
post
);
for
(
size_t
i
=
0
;
i
<
SHARE_NUM
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
SHARE_NUM
;
++
i
)
{
int
y_offset
=
i
*
n
;
int
y_offset
=
i
*
n
;
...
@@ -212,3 +208,4 @@ public:
...
@@ -212,3 +208,4 @@ public:
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
core/paddlefl_mpc/operators/mpc_elementwise_sub_op.cc
浏览文件 @
a1a9bf6b
/
/
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/
*
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.
limitations under the License. */
#include "mpc_elementwise_sub_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_elementwise_sub_op.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -22,25 +22,21 @@ class MpcElementwiseSubOp : public framework::OperatorWithKernel {
...
@@ -22,25 +22,21 @@ class MpcElementwiseSubOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
ctx
->
HasInput
(
"X"
),
true
,
platform
::
errors
::
NotFound
(
platform
::
errors
::
NotFound
(
"Input(X) of MpcElementwiseSubOp should not be null."
));
"Input(X) of MpcElementwiseSubOp should not be null."
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Y"
),
true
,
ctx
->
HasInput
(
"Y"
),
true
,
platform
::
errors
::
NotFound
(
platform
::
errors
::
NotFound
(
"Input(Y) of MpcElementwiseSubOp should not be null."
));
"Input(Y) of MpcElementwiseSubOp should not be null."
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
ctx
->
HasOutput
(
"Out"
),
true
,
platform
::
errors
::
NotFound
(
platform
::
errors
::
NotFound
(
"Output(Out) of MpcElementwiseSubOp should not be null."
));
"Output(Out) of MpcElementwiseSubOp should not be null."
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"X"
),
ctx
->
GetInputDim
(
"Y"
),
ctx
->
GetInputDim
(
"X"
),
ctx
->
GetInputDim
(
"Y"
),
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"The dimensions of X should be equal with the dimensions of Y. "
"The dimensions of X should be equal with the dimensions of Y. "
"But received the dimensions of X is [%s], the dimensions of Y is "
"But received the dimensions of X is [%s], the dimensions of Y is [%s]"
,
"[%s]"
,
ctx
->
GetInputDim
(
"X"
),
ctx
->
GetInputDim
(
"Y"
)));
ctx
->
GetInputDim
(
"X"
),
ctx
->
GetInputDim
(
"Y"
)));
ctx
->
ShareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareDim
(
"X"
,
/*->*/
"Out"
);
...
@@ -51,10 +47,8 @@ public:
...
@@ -51,10 +47,8 @@ public:
class
MpcElementwiseSubOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
MpcElementwiseSubOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
void
Make
()
override
{
void
Make
()
override
{
AddInput
(
"X"
,
AddInput
(
"X"
,
"(Tensor), The first input tensor of mpc elementwise sub op."
);
"(Tensor), The first input tensor of mpc elementwise sub op."
);
AddInput
(
"Y"
,
"(Tensor), The second input tensor of mpc elementwise sub op."
);
AddInput
(
"Y"
,
"(Tensor), The second input tensor of mpc elementwise sub op."
);
AddOutput
(
"Out"
,
"(Tensor), The output tensor of mpc elementwise sub op."
);
AddOutput
(
"Out"
,
"(Tensor), The output tensor of mpc elementwise sub op."
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
MPC elementwise sub Operator.
MPC elementwise sub Operator.
...
@@ -86,21 +80,19 @@ public:
...
@@ -86,21 +80,19 @@ public:
};
};
template
<
typename
T
>
template
<
typename
T
>
class
MpcElementwiseSubGradMaker
:
public
framework
::
SingleGradOp
DescMaker
{
class
MpcElementwiseSubGradMaker
:
public
framework
::
SingleGradOp
Maker
<
T
>
{
public:
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDesc
Maker
;
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOp
Maker
;
protected:
protected:
std
::
unique_ptr
<
T
>
Apply
()
const
override
{
void
Apply
(
GradOpPtr
<
T
>
grad
)
const
override
{
std
::
unique_ptr
<
T
>
retv
(
new
T
());
grad
->
SetType
(
"mpc_elementwise_sub_grad"
);
retv
->
SetType
(
"mpc_elementwise_sub_grad"
);
grad
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
retv
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
grad
->
SetInput
(
"Y"
,
this
->
Input
(
"Y"
));
retv
->
SetInput
(
"Y"
,
this
->
Input
(
"Y"
));
grad
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
retv
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
grad
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
retv
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
grad
->
SetOutput
(
framework
::
GradVarName
(
"Y"
),
this
->
InputGrad
(
"Y"
));
retv
->
SetOutput
(
framework
::
GradVarName
(
"Y"
),
this
->
InputGrad
(
"Y"
));
grad
->
SetAttrMap
(
this
->
Attrs
());
retv
->
SetAttrMap
(
this
->
Attrs
());
return
retv
;
}
}
};
};
...
@@ -118,6 +110,6 @@ REGISTER_OP_CPU_KERNEL(
...
@@ -118,6 +110,6 @@ REGISTER_OP_CPU_KERNEL(
mpc_elementwise_sub
,
mpc_elementwise_sub
,
ops
::
MpcElementwiseSubKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
ops
::
MpcElementwiseSubKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
REGISTER_OP_CPU_KERNEL
(
mpc_elementwise_sub_grad
,
REGISTER_OP_CPU_KERNEL
(
ops
::
MpcElementwiseSubGradKernel
<
mpc_elementwise_sub_grad
,
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
ops
::
MpcElementwiseSubGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
core/paddlefl_mpc/operators/mpc_elementwise_sub_op.h
浏览文件 @
a1a9bf6b
/
/
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/
*
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.
limitations under the License. */
// This op is different with elementwise_sub of PaddlePaddle.
// This op is different with elementwise_sub of PaddlePaddle.
// We only consider that the dimensions of X is equal with the dimensions of Y.
// We only consider that the dimensions of X is equal with the dimensions of Y.
#pragma once
#pragma once
#include "mpc_op.h"
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -33,8 +32,7 @@ public:
...
@@ -33,8 +32,7 @@ public:
auto
*
out_t
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
out_t
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
out
=
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
out
=
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
sub
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
sub
(
in_x_t
,
in_y_t
,
out_t
);
in_x_t
,
in_y_t
,
out_t
);
}
}
};
};
...
@@ -56,11 +54,11 @@ public:
...
@@ -56,11 +54,11 @@ public:
}
}
if
(
dy
)
{
if
(
dy
)
{
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
neg
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
neg
(
dout
,
dy
);
dout
,
dy
);
}
}
}
}
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
core/paddlefl_mpc/operators/mpc_init_op.cc
浏览文件 @
a1a9bf6b
/
/
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/
*
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.
limitations under the License. */
// Description:
// Description:
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_config.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_config.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -26,10 +26,10 @@ using mpc::Aby3Config;
...
@@ -26,10 +26,10 @@ using mpc::Aby3Config;
class
MpcInitOp
:
public
framework
::
OperatorBase
{
class
MpcInitOp
:
public
framework
::
OperatorBase
{
public:
public:
MpcInitOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
MpcInitOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
out
puts
,
const
framework
::
VariableNameMap
&
in
puts
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
VariableNameMap
&
outputs
,
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
RunImpl
(
const
framework
::
Scope
&
scope
,
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
...
@@ -55,24 +55,26 @@ public:
...
@@ -55,24 +55,26 @@ public:
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Where2 Operator.
Where2 Operator.
)DOC"
);
)DOC"
);
AddAttr
<
std
::
string
>
(
"protocol_name"
,
"(string , default aby3)"
AddAttr
<
std
::
string
>
(
"protocol_name"
,
"(string , default aby3)"
"protocol name"
)
"protocol name"
)
.
SetDefault
({
"aby3"
});
.
SetDefault
({
"aby3"
});
AddAttr
<
int
>
(
"role"
,
"trainer role."
).
SetDefault
(
0
);
AddAttr
<
int
>
(
"role"
,
"trainer role."
).
SetDefault
(
0
);
AddAttr
<
std
::
string
>
(
"local_addr"
,
"(string, default localhost)"
AddAttr
<
std
::
string
>
(
"local_addr"
,
"(string, default localhost)"
"local addr"
)
"local addr"
)
.
SetDefault
({
"localhost"
});
.
SetDefault
({
"localhost"
});
AddAttr
<
std
::
string
>
(
"net_server_addr"
,
"(string, default localhost)"
AddAttr
<
std
::
string
>
(
"net_server_addr"
,
"(string, default localhost)"
"net server addr"
)
"net server addr"
)
.
SetDefault
({
"localhost"
});
.
SetDefault
({
"localhost"
});
AddAttr
<
int
>
(
"net_server_port"
,
"net server port, default to 6539."
)
AddAttr
<
int
>
(
"net_server_port"
,
"net server port, default to 6539."
).
SetDefault
(
6539
);
.
SetDefault
(
6539
);
}
}
};
};
class
MpcInitOpShapeInference
:
public
framework
::
InferShapeBase
{
class
MpcInitOpShapeInference
:
public
framework
::
InferShapeBase
{
public:
public:
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{}
};
};
}
// namespace operators
}
// namespace operators
...
@@ -80,5 +82,7 @@ public:
...
@@ -80,5 +82,7 @@ public:
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
mpc_init
,
ops
::
MpcInitOp
,
ops
::
MpcInitOpMaker
,
REGISTER_OPERATOR
(
ops
::
MpcInitOpShapeInference
);
mpc_init
,
ops
::
MpcInitOp
,
ops
::
MpcInitOpMaker
,
ops
::
MpcInitOpShapeInference
);
core/paddlefl_mpc/operators/mpc_mean_op.cc
浏览文件 @
a1a9bf6b
/
/
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/
*
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.
limitations under the License. */
#include "mpc_mean_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_mean_op.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -24,13 +24,13 @@ class MpcMeanOp : public framework::OperatorWithKernel {
...
@@ -24,13 +24,13 @@ class MpcMeanOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
PADDLE_ENFORCE_EQ
(
platform
::
errors
::
NotFound
(
ctx
->
HasInput
(
"X"
),
true
,
"Input(X) of MpcMeanOp should not be null."
));
platform
::
errors
::
NotFound
(
"Input(X) of MpcMeanOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
PADDLE_ENFORCE_EQ
(
platform
::
errors
::
NotFound
(
ctx
->
HasOutput
(
"Out"
),
true
,
"Output(Out) of MpcMeanOp should not be null."
));
platform
::
errors
::
NotFound
(
"Output(Out) of MpcMeanOp should not be null."
));
ctx
->
SetOutputDim
(
"Out"
,
{
2
,
1
});
ctx
->
SetOutputDim
(
"Out"
,
{
2
,
1
});
}
}
};
};
...
@@ -48,9 +48,10 @@ MPC mean Operator calculates the mean of all elements in X.
...
@@ -48,9 +48,10 @@ MPC mean Operator calculates the mean of all elements in X.
class
MpcMeanOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
class
MpcMeanOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
GetInputOutputWithSameType
()
GetInputOutputWithSameType
()
const
override
{
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
}
};
};
...
@@ -63,21 +64,20 @@ public:
...
@@ -63,21 +64,20 @@ public:
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
framework
::
GradVarName
(
"X"
));
ctx
->
ShareLoD
(
"X"
,
framework
::
GradVarName
(
"X"
));
}
}
};
};
template
<
typename
T
>
template
<
typename
T
>
class
MpcMeanOpGradMaker
:
public
framework
::
SingleGradOp
DescMaker
{
class
MpcMeanOpGradMaker
:
public
framework
::
SingleGradOp
Maker
<
T
>
{
public:
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDesc
Maker
;
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOp
Maker
;
protected:
protected:
std
::
unique_ptr
<
T
>
Apply
()
const
override
{
void
Apply
(
GradOpPtr
<
T
>
grad
)
const
override
{
std
::
unique_ptr
<
T
>
retv
(
new
T
());
grad
->
SetType
(
"mpc_mean_grad"
);
retv
->
SetType
(
"mpc_mean_grad"
);
grad
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
retv
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
grad
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
retv
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
grad
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
retv
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
return
retv
;
}
}
};
};
...
@@ -85,14 +85,16 @@ protected:
...
@@ -85,14 +85,16 @@ protected:
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
mpc_mean
,
ops
::
MpcMeanOp
,
ops
::
MpcMeanOpMaker
,
REGISTER_OPERATOR
(
mpc_mean
,
ops
::
MpcMeanOp
,
ops
::
MpcMeanOpMaker
,
ops
::
MpcMeanOpInferVarType
,
ops
::
MpcMeanOpInferVarType
,
ops
::
MpcMeanOpGradMaker
<
paddle
::
framework
::
OpDesc
>
);
ops
::
MpcMeanOpGradMaker
<
paddle
::
framework
::
OpDesc
>
);
REGISTER_OPERATOR
(
mpc_mean_grad
,
ops
::
MpcMeanGradOp
);
REGISTER_OPERATOR
(
mpc_mean_grad
,
ops
::
MpcMeanGradOp
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
mpc_mean
,
ops
::
MpcMeanKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
mpc_mean
,
ops
::
MpcMeanKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
mpc_mean_grad
,
mpc_mean_grad
,
...
...
core/paddlefl_mpc/operators/mpc_mean_op.h
浏览文件 @
a1a9bf6b
/
/
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/
*
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.
limitations under the License. */
#pragma once
#pragma once
#include "mpc_op.h"
#include "mpc_op.h"
#include "paddle/fluid/framework/eigen.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -33,10 +32,8 @@ public:
...
@@ -33,10 +32,8 @@ public:
auto
*
out_t
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
out_t
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
double
scale
=
1.0
/
(
in_x_t
->
numel
()
/
2.0
);
double
scale
=
1.0
/
(
in_x_t
->
numel
()
/
2.0
);
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
sum
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
sum
(
in_x_t
,
out_t
);
in_x_t
,
out_t
);
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
scale
(
out_t
,
scale
,
out_t
);
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
scale
(
out_t
,
scale
,
out_t
);
}
}
};
};
...
@@ -45,8 +42,7 @@ class MpcMeanGradKernel : public MpcOpKernel<T> {
...
@@ -45,8 +42,7 @@ class MpcMeanGradKernel : public MpcOpKernel<T> {
public:
public:
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
dout
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
dout
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
PADDLE_ENFORCE
(
dout
->
numel
()
==
2
,
PADDLE_ENFORCE
(
dout
->
numel
()
==
2
,
"numel of MpcMean Gradient should be 2."
);
"numel of MpcMean Gradient should be 2."
);
auto
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
dout_data
=
dout
->
data
<
T
>
();
auto
dout_data
=
dout
->
data
<
T
>
();
...
@@ -60,11 +56,11 @@ public:
...
@@ -60,11 +56,11 @@ public:
}
}
double
scale_factor
=
1.0
/
(
dx
->
numel
()
/
2
);
double
scale_factor
=
1.0
/
(
dx
->
numel
()
/
2
);
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
scale
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
scale
(
dx
,
scale_factor
,
dx
);
dx
,
scale_factor
,
dx
);
}
}
}
}
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
core/paddlefl_mpc/operators/mpc_mul_op.cc
浏览文件 @
a1a9bf6b
/
/
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/
*
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.
limitations under the License. */
#include "mpc_mul_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_mul_op.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -24,16 +24,16 @@ class MpcMulOp : public framework::OperatorWithKernel {
...
@@ -24,16 +24,16 @@ class MpcMulOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
PADDLE_ENFORCE_EQ
(
platform
::
errors
::
NotFound
(
ctx
->
HasInput
(
"X"
),
true
,
"Input(X) of Mpc MulOp should not be null."
));
platform
::
errors
::
NotFound
(
"Input(X) of Mpc MulOp should not be null."
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Y"
),
true
,
ctx
->
HasInput
(
"Y"
),
true
,
platform
::
errors
::
NotFound
(
"Input(Y) of MpcMulOp should not be null."
));
platform
::
errors
::
NotFound
(
"Input(Y) of MpcMulOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
PADDLE_ENFORCE_EQ
(
platform
::
errors
::
NotFound
(
ctx
->
HasOutput
(
"Out"
),
true
,
"Output(Out) of MpcMulOp should not be null."
));
platform
::
errors
::
NotFound
(
"Output(Out) of MpcMulOp should not be null."
));
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
y_dims
=
ctx
->
GetInputDim
(
"Y"
);
auto
y_dims
=
ctx
->
GetInputDim
(
"Y"
);
...
@@ -86,8 +86,8 @@ public:
...
@@ -86,8 +86,8 @@ public:
x_dims
,
x_mat_width
,
y_dims
,
y_mat_height
));
x_dims
,
x_mat_width
,
y_dims
,
y_mat_height
));
std
::
vector
<
int64_t
>
output_dims
;
std
::
vector
<
int64_t
>
output_dims
;
output_dims
.
reserve
(
static_cast
<
size_t
>
(
1
+
x_num_col_dims
+
y_dims
.
size
()
-
output_dims
.
reserve
(
y_num_col_dims
));
static_cast
<
size_t
>
(
1
+
x_num_col_dims
+
y_dims
.
size
()
-
y_num_col_dims
));
for
(
int
i
=
0
;
i
<=
x_num_col_dims
;
++
i
)
{
// i=0, batch_size (share id)
for
(
int
i
=
0
;
i
<=
x_num_col_dims
;
++
i
)
{
// i=0, batch_size (share id)
output_dims
.
push_back
(
x_dims
[
i
]);
output_dims
.
push_back
(
x_dims
[
i
]);
...
@@ -153,7 +153,8 @@ public:
...
@@ -153,7 +153,8 @@ public:
"same purpose as scale_weights in OPs that support quantization."
"same purpose as scale_weights in OPs that support quantization."
"Only to be used with MKL-DNN INT8"
)
"Only to be used with MKL-DNN INT8"
)
.
SetDefault
({
1.0
f
});
.
SetDefault
({
1.0
f
});
AddAttr
<
float
>
(
"scale_out"
,
"scale_out to be used for int8 output data."
AddAttr
<
float
>
(
"scale_out"
,
"scale_out to be used for int8 output data."
"Only used with MKL-DNN INT8"
)
"Only used with MKL-DNN INT8"
)
.
SetDefault
(
1.0
f
);
.
SetDefault
(
1.0
f
);
AddAttr
<
bool
>
(
AddAttr
<
bool
>
(
...
@@ -169,9 +170,10 @@ MPC mul Operator.
...
@@ -169,9 +170,10 @@ MPC mul Operator.
class
MpcMulOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
class
MpcMulOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
GetInputOutputWithSameType
()
GetInputOutputWithSameType
()
const
override
{
const
override
{
return
std
::
unordered_map
<
std
::
string
,
std
::
string
>
{{
"X"
,
/*->*/
"Out"
}};
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
}
};
};
...
@@ -202,36 +204,37 @@ public:
...
@@ -202,36 +204,37 @@ public:
};
};
template
<
typename
T
>
template
<
typename
T
>
class
MpcMulOpGradMaker
:
public
framework
::
SingleGradOp
DescMaker
{
class
MpcMulOpGradMaker
:
public
framework
::
SingleGradOp
Maker
<
T
>
{
public:
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDesc
Maker
;
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOp
Maker
;
protected:
protected:
std
::
unique_ptr
<
T
>
Apply
()
const
override
{
void
Apply
(
GradOpPtr
<
T
>
grad
)
const
override
{
std
::
unique_ptr
<
T
>
retv
(
new
T
());
grad
->
SetType
(
"mpc_mul_grad"
);
retv
->
SetType
(
"mpc_mul_grad"
);
grad
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
retv
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
grad
->
SetInput
(
"Y"
,
this
->
Input
(
"Y"
));
retv
->
SetInput
(
"Y"
,
this
->
Input
(
"Y"
));
grad
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
retv
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
grad
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
retv
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
grad
->
SetOutput
(
framework
::
GradVarName
(
"Y"
),
this
->
InputGrad
(
"Y"
));
retv
->
SetOutput
(
framework
::
GradVarName
(
"Y"
),
this
->
InputGrad
(
"Y"
));
grad
->
SetAttrMap
(
this
->
Attrs
());
retv
->
SetAttrMap
(
this
->
Attrs
());
return
retv
;
}
}
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
mpc_mul
,
ops
::
MpcMulOp
,
ops
::
MpcMulOpMaker
,
REGISTER_OPERATOR
(
mpc_mul
,
ops
::
MpcMulOp
,
ops
::
MpcMulOpMaker
,
ops
::
MpcMulOpInferVarType
,
ops
::
MpcMulOpInferVarType
,
ops
::
MpcMulOpGradMaker
<
paddle
::
framework
::
OpDesc
>
);
ops
::
MpcMulOpGradMaker
<
paddle
::
framework
::
OpDesc
>
);
REGISTER_OPERATOR
(
mpc_mul_grad
,
ops
::
MpcMulGradOp
);
REGISTER_OPERATOR
(
mpc_mul_grad
,
ops
::
MpcMulGradOp
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
mpc_mul
,
ops
::
MpcMulKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
mpc_mul
,
ops
::
MpcMulKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
mpc_mul_grad
,
mpc_mul_grad
,
...
...
core/paddlefl_mpc/operators/mpc_mul_op.h
浏览文件 @
a1a9bf6b
/
/
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/
*
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.
limitations under the License. */
#pragma once
#pragma once
#include "mpc_op.h"
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -48,7 +47,7 @@ public:
...
@@ -48,7 +47,7 @@ public:
}
}
for
(
size_t
i
=
1
;
i
<
y_dims
.
size
();
i
++
)
{
for
(
size_t
i
=
1
;
i
<
y_dims
.
size
();
i
++
)
{
if
(
i
<=
y_num_col_dims
)
{
if
(
i
<=
y_num_col_dims
)
{
x
_mat_width
*=
y_dims
[
i
];
y
_mat_width
*=
y_dims
[
i
];
}
else
{
}
else
{
y_mat_height
*=
y_dims
[
i
];
y_mat_height
*=
y_dims
[
i
];
}
}
...
@@ -59,13 +58,8 @@ public:
...
@@ -59,13 +58,8 @@ public:
x_matrix
.
ShareDataWith
(
*
x
);
x_matrix
.
ShareDataWith
(
*
x
);
y_matrix
.
ShareDataWith
(
*
y
);
y_matrix
.
ShareDataWith
(
*
y
);
if
(
x_dims
.
size
()
>
3
)
{
x_matrix
.
Resize
({
2
,
x_mat_width
,
x_mat_height
});
x_matrix
.
Resize
({
2
,
x_mat_width
,
x_mat_height
});
}
if
(
y_dims
.
size
()
>
3
)
{
y_matrix
.
Resize
({
2
,
y_mat_width
,
y_mat_height
});
y_matrix
.
Resize
({
2
,
y_mat_width
,
y_mat_height
});
}
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
@@ -80,15 +74,17 @@ public:
...
@@ -80,15 +74,17 @@ public:
if
(
out_dim
.
size
()
>
3
)
{
if
(
out_dim
.
size
()
>
3
)
{
out
->
Resize
(
out_dim
);
out
->
Resize
(
out_dim
);
}
}
}
}
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
MpcMulGradKernel
:
public
MpcOpKernel
<
T
>
{
class
MpcMulGradKernel
:
public
MpcOpKernel
<
T
>
{
public:
public:
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
y
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Y"
);
auto
*
y
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Y"
);
auto
*
dout
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dout
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dx
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dx
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
dy
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Y"
));
...
@@ -125,17 +121,9 @@ public:
...
@@ -125,17 +121,9 @@ public:
y_matrix
.
ShareDataWith
(
*
y
);
y_matrix
.
ShareDataWith
(
*
y
);
dout_matrix
.
ShareDataWith
(
*
dout
);
dout_matrix
.
ShareDataWith
(
*
dout
);
if
(
x_dims
.
size
()
>
3
)
{
x_matrix
.
Resize
({
2
,
x_mat_width
,
x_mat_height
});
x_matrix
.
Resize
({
2
,
x_mat_width
,
x_mat_height
});
}
if
(
y_dims
.
size
()
>
3
)
{
y_matrix
.
Resize
({
2
,
y_mat_width
,
y_mat_height
});
y_matrix
.
Resize
({
2
,
y_mat_width
,
y_mat_height
});
}
if
(
dout_dims
.
size
()
>
3
)
{
dout_matrix
.
Resize
({
2
,
x_mat_width
,
y_mat_height
});
dout_matrix
.
Resize
({
2
,
x_mat_width
,
y_mat_height
});
}
if
(
dx
!=
nullptr
)
{
if
(
dx
!=
nullptr
)
{
dx
->
set_lod
(
x
->
lod
());
dx
->
set_lod
(
x
->
lod
());
...
@@ -149,15 +137,10 @@ public:
...
@@ -149,15 +137,10 @@ public:
x_matrix_trans
.
mutable_data
<
T
>
(
x
->
dims
(),
ctx
.
GetPlace
());
x_matrix_trans
.
mutable_data
<
T
>
(
x
->
dims
(),
ctx
.
GetPlace
());
y_matrix_trans
.
mutable_data
<
T
>
(
y
->
dims
(),
ctx
.
GetPlace
());
y_matrix_trans
.
mutable_data
<
T
>
(
y
->
dims
(),
ctx
.
GetPlace
());
if
(
x_dims
.
size
()
>=
3
)
{
x_matrix_trans
.
Resize
({
2
,
x_mat_height
,
x_mat_width
});
x_matrix_trans
.
Resize
({
2
,
x_mat_height
,
x_mat_width
});
}
if
(
y_dims
.
size
()
>=
3
)
{
y_matrix_trans
.
Resize
({
2
,
y_mat_height
,
y_mat_width
});
y_matrix_trans
.
Resize
({
2
,
y_mat_height
,
y_mat_width
});
}
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
const
int
Rank
=
3
;
const
int
Rank
=
3
;
Eigen
::
array
<
int
,
Rank
>
permute
;
Eigen
::
array
<
int
,
Rank
>
permute
;
...
@@ -172,7 +155,7 @@ public:
...
@@ -172,7 +155,7 @@ public:
}
}
auto
eigen_in
=
framework
::
EigenTensor
<
T
,
Rank
>::
From
(
y_matrix
);
auto
eigen_in
=
framework
::
EigenTensor
<
T
,
Rank
>::
From
(
y_matrix
);
auto
eigen_out
=
framework
::
EigenTensor
<
T
,
Rank
>::
From
(
y_matrix_trans
);
auto
eigen_out
=
framework
::
EigenTensor
<
T
,
Rank
>::
From
(
y_matrix_trans
);
auto
*
dev
=
dev_ctx
.
eigen_device
();
auto
*
dev
=
dev_ctx
.
eigen_device
();
eigen_out
.
device
(
*
dev
)
=
eigen_in
.
shuffle
(
permute
);
eigen_out
.
device
(
*
dev
)
=
eigen_in
.
shuffle
(
permute
);
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
matmul
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
matmul
(
...
@@ -191,7 +174,7 @@ public:
...
@@ -191,7 +174,7 @@ public:
auto
eigen_in
=
framework
::
EigenTensor
<
T
,
Rank
>::
From
(
x_matrix
);
auto
eigen_in
=
framework
::
EigenTensor
<
T
,
Rank
>::
From
(
x_matrix
);
auto
eigen_out
=
framework
::
EigenTensor
<
T
,
Rank
>::
From
(
x_matrix_trans
);
auto
eigen_out
=
framework
::
EigenTensor
<
T
,
Rank
>::
From
(
x_matrix_trans
);
auto
*
dev
=
dev_ctx
.
eigen_device
();
auto
*
dev
=
dev_ctx
.
eigen_device
();
eigen_out
.
device
(
*
dev
)
=
eigen_in
.
shuffle
(
permute
);
eigen_out
.
device
(
*
dev
)
=
eigen_in
.
shuffle
(
permute
);
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
matmul
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
matmul
(
...
@@ -206,3 +189,4 @@ public:
...
@@ -206,3 +189,4 @@ public:
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
core/paddlefl_mpc/operators/mpc_op.h
浏览文件 @
a1a9bf6b
/
/
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/
*
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.
limitations under the License. */
// Description:
// Description:
#pragma once
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/privc3/circuit_context.h"
#include "core/privc3/circuit_context.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
class
MpcOpKernel
:
public
framework
::
OpKernelBase
{
template
<
typename
T
>
class
MpcOpKernel
:
public
framework
::
OpKernelBase
{
public:
public:
using
ELEMENT_TYPE
=
T
;
using
ELEMENT_TYPE
=
T
;
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
(),
PADDLE_ENFORCE_NOT_NULL
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
(),
"Mpc protocol is not yet initialized in executor"
);
"Mpc protocol is not yet initialized in executor"
);
std
::
shared_ptr
<
aby3
::
CircuitContext
>
mpc_ctx
(
std
::
shared_ptr
<
aby3
::
CircuitContext
>
mpc_ctx
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_context
());
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_context
());
mpc
::
ContextHolder
::
template
run_with_context
<
>(
&
ctx
,
mpc_ctx
,
mpc
::
ContextHolder
::
template
run_with_context
<
>(
&
ctx
,
mpc_ctx
,
[
&
]
{
ComputeImpl
(
ctx
);
});
[
&
]
{
ComputeImpl
(
ctx
);
});
}
}
virtual
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
=
0
;
virtual
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
=
0
;
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
core/paddlefl_mpc/operators/mpc_relu_op.cc
浏览文件 @
a1a9bf6b
...
@@ -18,20 +18,20 @@
...
@@ -18,20 +18,20 @@
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
//
forward op defination
//forward op defination
class
MpcReluOp
:
public
framework
::
OperatorWithKernel
{
class
MpcReluOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
ctx
->
SetOutputDim
(
"Y"
,
in_dims
);
ctx
->
SetOutputDim
(
"Y"
,
in_dims
);
}
}
};
};
//
forward input & output defination
//
forward input & output defination
class
MpcReluOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
MpcReluOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
void
Make
()
override
{
void
Make
()
override
{
AddInput
(
"X"
,
"The input tensor."
);
AddInput
(
"X"
,
"The input tensor."
);
AddOutput
(
"Y"
,
"Output of relu_op"
);
AddOutput
(
"Y"
,
"Output of relu_op"
);
...
@@ -41,31 +41,30 @@ Mpc Relu Operator.
...
@@ -41,31 +41,30 @@ Mpc Relu Operator.
}
}
};
};
//
backward op defination
//backward op defination
class
MpcReluGradOp
:
public
framework
::
OperatorWithKernel
{
class
MpcReluGradOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
in_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Y"
));
auto
in_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Y"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
in_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
in_dims
);
}
}
};
};
//
backward type, input & output defination
//backward type, input & output defination
template
<
typename
T
>
template
<
typename
T
>
class
MpcReluGradMaker
:
public
framework
::
SingleGradOp
DescMaker
{
class
MpcReluGradMaker
:
public
framework
::
SingleGradOp
Maker
<
T
>
{
public:
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDesc
Maker
;
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOp
Maker
;
std
::
unique_ptr
<
T
>
Apply
()
const
override
{
protected:
auto
*
op
=
new
T
();
void
Apply
(
GradOpPtr
<
T
>
grad
)
const
override
{
op
->
SetType
(
"mpc_relu_grad"
);
grad
->
SetType
(
"mpc_relu_grad"
);
op
->
SetInput
(
"Y"
,
this
->
Output
(
"Y"
));
grad
->
SetInput
(
"Y"
,
this
->
Output
(
"Y"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Y"
),
this
->
OutputGrad
(
"Y"
));
grad
->
SetInput
(
framework
::
GradVarName
(
"Y"
),
this
->
OutputGrad
(
"Y"
));
op
->
SetAttrMap
(
this
->
Attrs
());
grad
->
SetAttrMap
(
this
->
Attrs
());
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
grad
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
return
std
::
unique_ptr
<
T
>
(
op
);
}
}
};
};
...
@@ -76,8 +75,12 @@ namespace ops = paddle::operators;
...
@@ -76,8 +75,12 @@ namespace ops = paddle::operators;
using
CPU
=
paddle
::
platform
::
CPUDeviceContext
;
using
CPU
=
paddle
::
platform
::
CPUDeviceContext
;
REGISTER_OPERATOR
(
mpc_relu
,
ops
::
MpcReluOp
,
ops
::
MpcReluOpMaker
,
REGISTER_OPERATOR
(
mpc_relu
,
ops
::
MpcReluOp
,
ops
::
MpcReluOpMaker
,
ops
::
MpcReluGradMaker
<
paddle
::
framework
::
OpDesc
>
);
ops
::
MpcReluGradMaker
<
paddle
::
framework
::
OpDesc
>
);
REGISTER_OPERATOR
(
mpc_relu_grad
,
ops
::
MpcReluGradOp
);
REGISTER_OPERATOR
(
mpc_relu_grad
,
ops
::
MpcReluGradOp
);
REGISTER_OP_CPU_KERNEL
(
mpc_relu
,
ops
::
MpcReluKernel
<
CPU
,
int64_t
>
);
REGISTER_OP_CPU_KERNEL
(
mpc_relu
,
REGISTER_OP_CPU_KERNEL
(
mpc_relu_grad
,
ops
::
MpcReluGradKernel
<
CPU
,
int64_t
>
);
ops
::
MpcReluKernel
<
CPU
,
int64_t
>
);
REGISTER_OP_CPU_KERNEL
(
mpc_relu_grad
,
ops
::
MpcReluGradKernel
<
CPU
,
int64_t
>
);
core/paddlefl_mpc/operators/mpc_relu_op.h
浏览文件 @
a1a9bf6b
...
@@ -14,43 +14,37 @@
...
@@ -14,43 +14,37 @@
#pragma once
#pragma once
#include "mpc_op.h"
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
//
Define forward computation
//Define forward computation
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
MpcReluKernel
:
public
MpcOpKernel
<
T
>
{
class
MpcReluKernel
:
public
MpcOpKernel
<
T
>
{
public:
public:
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
Tensor
*
in_t
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
in_t
=
ctx
.
Input
<
Tensor
>
(
"X"
);
Tensor
*
out_t
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
Tensor
*
out_t
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
x
=
in_t
->
data
<
T
>
();
auto
x
=
in_t
->
data
<
T
>
();
auto
y
=
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
y
=
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
PADDLE_ENFORCE_NOT_NULL
(
mpc
::
MpcInstance
::
mpc_protocol
,
PADDLE_ENFORCE_NOT_NULL
(
mpc
::
MpcInstance
::
mpc_protocol
,
"Protocol %s is not yet created in MPC Protocol."
);
"Protocol %s is not yet created in MPC Protocol."
);
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
relu
(
in_t
,
out_t
);
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
relu
(
in_t
,
out_t
);
}
}
};
};
//
Define backward computation
//Define backward computation
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
MpcReluGradKernel
:
public
MpcOpKernel
<
T
>
{
class
MpcReluGradKernel
:
public
MpcOpKernel
<
T
>
{
public:
public:
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
dy_t
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
dy_t
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
*
y_t
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
auto
*
y_t
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
auto
*
dx_t
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dx_t
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
dx
=
dx_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
dx
=
dx_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
mpc
::
MpcInstance
::
mpc_instance
()
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
relu_grad
(
y_t
,
dy_t
,
dx_t
,
0.0
);
->
mpc_protocol
()
->
mpc_operators
()
->
relu_grad
(
y_t
,
dy_t
,
dx_t
,
0.0
);
}
}
};
};
}
// namespace operaters
}
// namespace operaters
}
// namespace paddle
}
// namespace paddle
core/paddlefl_mpc/operators/mpc_sgd_op.cc
浏览文件 @
a1a9bf6b
/
/
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/
*
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.
limitations under the License. */
#include "mpc_sgd_op.h"
#include "mpc_sgd_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
...
@@ -55,8 +55,8 @@ public:
...
@@ -55,8 +55,8 @@ public:
}
}
protected:
protected:
framework
::
OpKernelType
framework
::
OpKernelType
GetExpectedKernelType
(
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"Param"
);
auto
data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"Param"
);
return
framework
::
OpKernelType
(
data_type
,
ctx
.
device_context
());
return
framework
::
OpKernelType
(
data_type
,
ctx
.
device_context
());
}
}
...
@@ -65,19 +65,19 @@ protected:
...
@@ -65,19 +65,19 @@ protected:
class
MpcSGDOpInferVarType
:
public
framework
::
VarTypeInference
{
class
MpcSGDOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
&
input_var_n
=
ctx
->
Input
(
"Param"
)[
0
];
auto
in_var_type
=
ctx
->
GetInputType
(
"Param"
);
auto
in_var_type
=
ctx
->
GetType
(
input_var_n
);
PADDLE_ENFORCE
(
in_var_type
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
||
PADDLE_ENFORCE
(
in_var_type
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
||
in_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
in_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
,
"The input Var's type should be LoDtensor or SelectedRows,"
"The input Var's type should be LoDtensor or SelectedRows,"
" but the received var(%s)'s type is %s"
,
" but the received var(%s)'s type is %s"
,
input_var_n
,
in_var_type
);
ctx
->
InputVarName
(
"Param"
),
in_var_type
);
ctx
->
SetOutputType
(
"ParamOut"
,
in_var_type
);
for
(
auto
&
out_var_n
:
ctx
->
Output
(
"ParamOut"
))
{
//for (auto &out_var_n : framework::StaticGraphVarTypeInference::Output(ctx,
"ParamOut")) {
if
(
ctx
->
Get
Type
(
out_var_n
)
!=
in_var_type
)
{
// if (ctx->GetVar
Type(out_var_n) != in_var_type) {
ctx
->
SetType
(
out_var_n
,
in_var_type
);
//
ctx->SetType(out_var_n, in_var_type);
}
//
}
}
//
}
}
}
};
};
...
@@ -108,7 +108,7 @@ $$param\_out = param - learning\_rate * grad$$
...
@@ -108,7 +108,7 @@ $$param\_out = param - learning\_rate * grad$$
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
REGISTER_OPERATOR
(
mpc_sgd
,
ops
::
MpcSGDOp
,
ops
::
MpcSGDOpMaker
,
mpc_sgd
,
ops
::
MpcSGDOp
,
ops
::
MpcSGDOpMaker
,
// paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
ops
::
MpcSGDOpInferVarType
);
ops
::
MpcSGDOpInferVarType
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
mpc_sgd
,
ops
::
MpcSGDOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
mpc_sgd
,
ops
::
MpcSGDOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
core/paddlefl_mpc/operators/mpc_sgd_op.h
浏览文件 @
a1a9bf6b
/
/
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/
*
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.
limitations under the License. */
#pragma once
#pragma once
#include "mpc_op.h"
#include "mpc_op.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/eigen.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
MpcSGDOpKernel
:
public
MpcOpKernel
<
T
>
{
class
MpcSGDOpKernel
:
public
MpcOpKernel
<
T
>
{
public:
public:
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE_EQ
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
true
,
PADDLE_ENFORCE_EQ
(
param_var
->
IsType
<
framework
::
LoDTensor
>
(),
true
,
"The Var(%s)'s type should be LoDTensor, "
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
"but the received is %s"
,
ctx
.
Input
s
(
"Param"
).
front
(),
ctx
.
InputName
s
(
"Param"
).
front
(),
framework
::
ToTypeName
(
param_var
->
Type
()));
framework
::
ToTypeName
(
param_var
->
Type
()));
const
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
const
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
PADDLE_ENFORCE_EQ
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
(),
true
,
PADDLE_ENFORCE_EQ
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
(),
true
,
"The Var(%s)'s type should be LoDTensor, "
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s"
,
"but the received is %s"
,
ctx
.
Input
s
(
"Grad"
).
front
(),
ctx
.
InputName
s
(
"Grad"
).
front
(),
framework
::
ToTypeName
(
grad_var
->
Type
()));
framework
::
ToTypeName
(
grad_var
->
Type
()));
const
auto
*
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
const
auto
*
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
...
@@ -49,19 +48,14 @@ public:
...
@@ -49,19 +48,14 @@ public:
PADDLE_ENFORCE_EQ
(
grad
->
numel
(),
sz
);
PADDLE_ENFORCE_EQ
(
grad
->
numel
(),
sz
);
const
double
*
lr
=
learning_rate
->
data
<
double
>
();
const
double
*
lr
=
learning_rate
->
data
<
double
>
();
// const T *param_data = param->data<T>();
// const T *grad_data = grad->data<T>();
T
*
out_data
=
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
PADDLE_ENFORCE_NOT_NULL
(
mpc
::
MpcInstance
::
mpc_protocol
,
PADDLE_ENFORCE_NOT_NULL
(
mpc
::
MpcInstance
::
mpc_protocol
,
"Protocol %s is not yet created in MPC Protocol."
);
"Protocol %s is not yet created in MPC Protocol."
);
// update parameters
// update parameters
framework
::
Tensor
temp
;
framework
::
Tensor
temp
;
temp
.
mutable_data
<
T
>
(
param
->
dims
(),
ctx
.
GetPlace
());
temp
.
mutable_data
<
T
>
(
param
->
dims
(),
ctx
.
GetPlace
());
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
scale
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
scale
(
grad
,
lr
[
0
],
&
temp
);
grad
,
lr
[
0
],
&
temp
);
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
sub
(
param
,
&
temp
,
param_out
);
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
sub
(
param
,
&
temp
,
param_out
);
}
}
};
};
}
// namespace operators
}
// namespace operators
...
...
core/paddlefl_mpc/operators/mpc_sigmoid_cross_entropy_with_logits_op.cc
浏览文件 @
a1a9bf6b
...
@@ -117,21 +117,19 @@ MpcSigmoidCrossEntropyWithLogits Operator.
...
@@ -117,21 +117,19 @@ MpcSigmoidCrossEntropyWithLogits Operator.
};
};
template
<
typename
T
>
template
<
typename
T
>
class
MpcSigmoidCrossEntropyWithLogitsGradOpMaker
:
public
framework
::
SingleGradOp
DescMaker
{
class
MpcSigmoidCrossEntropyWithLogitsGradOpMaker
:
public
framework
::
SingleGradOp
Maker
<
T
>
{
public:
public:
using
framework
::
SingleGradOp
DescMaker
::
SingleGradOpDesc
Maker
;
using
framework
::
SingleGradOp
Maker
<
T
>::
SingleGradOp
Maker
;
protected:
protected:
std
::
unique_ptr
<
T
>
Apply
()
const
override
{
void
Apply
(
GradOpPtr
<
T
>
grad
)
const
override
{
std
::
unique_ptr
<
T
>
retv
(
new
T
());
grad
->
SetType
(
"mpc_sigmoid_cross_entropy_with_logits_grad"
);
retv
->
SetType
(
"mpc_sigmoid_cross_entropy_with_logits_grad"
);
grad
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
retv
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
grad
->
SetInput
(
"Label"
,
this
->
Input
(
"Label"
));
retv
->
SetInput
(
"Label"
,
this
->
Input
(
"Label"
));
grad
->
SetInput
(
"Out"
,
this
->
Output
(
"Out"
));
retv
->
SetInput
(
"Out"
,
this
->
Output
(
"Out"
));
grad
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
retv
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
grad
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
retv
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
grad
->
SetAttrMap
(
this
->
Attrs
());
retv
->
SetAttrMap
(
this
->
Attrs
());
return
retv
;
}
}
};
};
...
...
core/paddlefl_mpc/operators/mpc_square_op.cc
浏览文件 @
a1a9bf6b
/
/
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/
*
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.
limitations under the License. */
#include "mpc_square_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_square_op.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
class
MpcSquareOp
:
public
framework
::
OperatorWithKernel
{
class
MpcSquareOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
PADDLE_ENFORCE_EQ
(
platform
::
errors
::
NotFound
(
ctx
->
HasInput
(
"X"
),
true
,
"Input(X) of MpcSquareOp should not be null."
));
platform
::
errors
::
NotFound
(
"Input(X) of MpcSquareOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
PADDLE_ENFORCE_EQ
(
platform
::
errors
::
NotFound
(
ctx
->
HasOutput
(
"Out"
),
true
,
"Output(Out) of MpcSquareOp should not be null."
));
platform
::
errors
::
NotFound
(
"Output(Out) of MpcSquareOp should not be null."
));
ctx
->
ShareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
}
...
@@ -59,26 +60,26 @@ public:
...
@@ -59,26 +60,26 @@ public:
};
};
template
<
typename
T
>
template
<
typename
T
>
class
MpcSquareGradOpMaker
:
public
framework
::
SingleGradOp
DescMaker
{
class
MpcSquareGradOpMaker
:
public
framework
::
SingleGradOp
Maker
<
T
>
{
public:
public:
using
framework
::
SingleGradOpDescMaker
::
SingleGradOpDesc
Maker
;
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOp
Maker
;
protected:
protected:
std
::
unique_ptr
<
T
>
Apply
()
const
override
{
void
Apply
(
GradOpPtr
<
T
>
grad
)
const
override
{
std
::
unique_ptr
<
T
>
retv
(
new
T
());
grad
->
SetType
(
"mpc_square_grad"
);
retv
->
SetType
(
"mpc_square_grad"
);
grad
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
retv
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
grad
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
retv
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
grad
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
retv
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
return
retv
;
}
}
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
mpc_square
,
ops
::
MpcSquareOp
,
ops
::
MpcSquareOpMaker
,
REGISTER_OPERATOR
(
mpc_square
,
ops
::
MpcSquareOp
,
ops
::
MpcSquareOpMaker
,
ops
::
MpcSquareGradOpMaker
<
paddle
::
framework
::
OpDesc
>
);
ops
::
MpcSquareGradOpMaker
<
paddle
::
framework
::
OpDesc
>
);
REGISTER_OPERATOR
(
mpc_square_grad
,
ops
::
MpcSquareGradOp
);
REGISTER_OPERATOR
(
mpc_square_grad
,
ops
::
MpcSquareGradOp
);
...
...
core/paddlefl_mpc/operators/mpc_square_op.h
浏览文件 @
a1a9bf6b
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.
limitations under the License. */
#pragma once
#pragma once
#include "mpc_op.h"
#include "mpc_op.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -27,8 +27,7 @@ public:
...
@@ -27,8 +27,7 @@ public:
auto
*
in_x_t
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
in_x_t
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
out_t
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
out_t
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
mul
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
mul
(
in_x_t
,
in_x_t
,
out_t
);
in_x_t
,
in_x_t
,
out_t
);
}
}
};
};
...
@@ -43,13 +42,12 @@ public:
...
@@ -43,13 +42,12 @@ public:
// allocate memory on device.
// allocate memory on device.
dx_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
dx_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// dx = dout * 2 * x
// dx = dout * 2 * x
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
scale
(
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
scale
(
in_x_t
,
2.0
,
dx_t
);
in_x_t
,
2.0
,
dx_t
);
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
mul
(
dx_t
,
dout_t
,
dx_t
);
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
mul
(
dx_t
,
dout_t
,
dx_t
);
}
}
}
}
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
core/paddlefl_mpc/operators/mpc_sum_op.cc
浏览文件 @
a1a9bf6b
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <algorithm>
#include <memory>
#include <memory>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
#include "mpc_sum_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/framework/op_registry.h"
#include "mpc_sum_op.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -31,11 +30,10 @@ class MpcSumOp : public framework::OperatorWithKernel {
...
@@ -31,11 +30,10 @@ class MpcSumOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
ctx
->
HasInputs
(
"X"
),
true
,
ctx
->
HasInputs
(
"X"
),
true
,
platform
::
errors
::
NotFound
(
platform
::
errors
::
NotFound
(
"Input(X) of MpcElementwiseAddOp should not be null."
));
"Input(X) of MpcElementwiseAddOp should not be null."
));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
ctx
->
HasOutput
(
"Out"
),
true
,
...
@@ -45,7 +43,8 @@ public:
...
@@ -45,7 +43,8 @@ public:
auto
x_dims
=
ctx
->
GetInputsDim
(
"X"
);
auto
x_dims
=
ctx
->
GetInputsDim
(
"X"
);
auto
N
=
x_dims
.
size
();
auto
N
=
x_dims
.
size
();
PADDLE_ENFORCE_GT
(
PADDLE_ENFORCE_GT
(
N
,
0
,
"ShapeError: The input tensor X's dimensions of SumOp "
N
,
0
,
"ShapeError: The input tensor X's dimensions of SumOp "
"should be larger than 0. But received X's dimensions %d, "
"should be larger than 0. But received X's dimensions %d, "
"X's shape = [%s]."
,
"X's shape = [%s]."
,
N
,
&
x_dims
);
N
,
&
x_dims
);
...
@@ -55,7 +54,7 @@ public:
...
@@ -55,7 +54,7 @@ public:
framework
::
DDim
in_dim
({
0
});
framework
::
DDim
in_dim
({
0
});
for
(
size_t
i
=
0
;
i
<
x_dims
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
x_dims
.
size
();
++
i
)
{
auto
&
x_dim
=
x_dims
[
i
];
auto
&
x_dim
=
x_dims
[
i
];
// x_dim.size() == 1 means the real dim of selected rows is [0]
// x_dim.size() == 1 means the real dim of selected rows is [0]
if
(
x_var_types
[
i
]
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
&&
if
(
x_var_types
[
i
]
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
&&
x_dim
.
size
()
==
1
)
{
x_dim
.
size
()
==
1
)
{
...
@@ -99,6 +98,7 @@ public:
...
@@ -99,6 +98,7 @@ public:
ctx
->
SetOutputDim
(
"Out"
,
in_dim
);
ctx
->
SetOutputDim
(
"Out"
,
in_dim
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
}
};
};
class
MpcSumOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
MpcSumOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
@@ -110,7 +110,8 @@ public:
...
@@ -110,7 +110,8 @@ public:
"or LoDTensor, and data types can be: float32, float64, int32, "
"or LoDTensor, and data types can be: float32, float64, int32, "
"int64."
)
"int64."
)
.
AsDuplicable
();
.
AsDuplicable
();
AddOutput
(
"Out"
,
"the sum of input :code:`x`. its shape and data types are "
AddOutput
(
"Out"
,
"the sum of input :code:`x`. its shape and data types are "
"consistent with :code:`x`."
);
"consistent with :code:`x`."
);
AddAttr
<
bool
>
(
"use_mkldnn"
,
AddAttr
<
bool
>
(
"use_mkldnn"
,
"(bool, default false) Only used in mkldnn kernel"
)
"(bool, default false) Only used in mkldnn kernel"
)
...
@@ -121,6 +122,7 @@ public:
...
@@ -121,6 +122,7 @@ public:
}
}
};
};
class
MpcSumGradMaker
:
public
framework
::
GradOpDescMakerBase
{
class
MpcSumGradMaker
:
public
framework
::
GradOpDescMakerBase
{
public:
public:
using
framework
::
GradOpDescMakerBase
::
GradOpDescMakerBase
;
using
framework
::
GradOpDescMakerBase
::
GradOpDescMakerBase
;
...
@@ -131,8 +133,8 @@ public:
...
@@ -131,8 +133,8 @@ public:
grad_ops
.
reserve
(
x_grads
.
size
());
grad_ops
.
reserve
(
x_grads
.
size
());
auto
og
=
OutputGrad
(
"Out"
);
auto
og
=
OutputGrad
(
"Out"
);
std
::
transform
(
x_grads
.
begin
(),
x_grads
.
end
(),
std
::
back_inserter
(
grad_ops
),
std
::
transform
(
x_grads
.
begin
(),
x_grads
.
end
(),
std
::
back_inserter
(
grad_ops
),
[
&
og
](
const
std
::
string
&
x_grad
)
{
[
&
og
](
const
std
::
string
&
x_grad
)
{
auto
*
grad_op
=
new
framework
::
OpDesc
();
auto
*
grad_op
=
new
framework
::
OpDesc
();
grad_op
->
SetType
(
"scale"
);
grad_op
->
SetType
(
"scale"
);
grad_op
->
SetInput
(
"X"
,
og
);
grad_op
->
SetInput
(
"X"
,
og
);
grad_op
->
SetOutput
(
"Out"
,
{
x_grad
});
grad_op
->
SetOutput
(
"Out"
,
{
x_grad
});
...
@@ -151,9 +153,10 @@ DECLARE_INPLACE_OP_INFERER(MpcSumInplace, {"X", "Out"});
...
@@ -151,9 +153,10 @@ DECLARE_INPLACE_OP_INFERER(MpcSumInplace, {"X", "Out"});
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
// REGISTER_OP_WITHOUT_GRADIENT(mpc_sum, ops::MpcSumOp, ops::MpcSumOpMaker);
//REGISTER_OP_WITHOUT_GRADIENT(mpc_sum, ops::MpcSumOp, ops::MpcSumOpMaker);
REGISTER_OPERATOR
(
mpc_sum
,
ops
::
MpcSumOp
,
ops
::
MpcSumOpMaker
,
REGISTER_OPERATOR
(
mpc_sum
,
ops
::
MpcSumOp
,
ops
::
MpcSumGradMaker
,
ops
::
MpcSumInplace
);
ops
::
MpcSumOpMaker
,
ops
::
MpcSumGradMaker
,
ops
::
MpcSumInplace
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
mpc_sum
,
ops
::
MpcSumKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
mpc_sum
,
ops
::
MpcSumKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
core/paddlefl_mpc/operators/mpc_sum_op.h
浏览文件 @
a1a9bf6b
/
/
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/
*
Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
// limitations under the License.
limitations under the License. */
#pragma once
#pragma once
#include "mpc_op.h"
#include "mpc_op.h"
...
@@ -45,10 +45,7 @@ public:
...
@@ -45,10 +45,7 @@ public:
auto
&
in_0
=
in_vars
[
0
]
->
Get
<
framework
::
LoDTensor
>
();
auto
&
in_0
=
in_vars
[
0
]
->
Get
<
framework
::
LoDTensor
>
();
auto
&
in_1
=
in_vars
[
1
]
->
Get
<
framework
::
LoDTensor
>
();
auto
&
in_1
=
in_vars
[
1
]
->
Get
<
framework
::
LoDTensor
>
();
if
(
in_0
.
numel
()
&&
in_1
.
numel
())
{
if
(
in_0
.
numel
()
&&
in_1
.
numel
())
{
mpc
::
MpcInstance
::
mpc_instance
()
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
add
(
&
in_0
,
&
in_1
,
out
);
->
mpc_protocol
()
->
mpc_operators
()
->
add
(
&
in_0
,
&
in_1
,
out
);
start
=
2
;
start
=
2
;
}
}
}
}
...
@@ -66,15 +63,12 @@ public:
...
@@ -66,15 +63,12 @@ public:
if
(
in_t
.
numel
()
==
0
)
{
if
(
in_t
.
numel
()
==
0
)
{
continue
;
continue
;
}
}
mpc
::
MpcInstance
::
mpc_instance
()
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
add
(
out
,
&
in_t
,
out
);
->
mpc_protocol
()
->
mpc_operators
()
->
add
(
out
,
&
in_t
,
out
);
}
else
{
}
else
{
PADDLE_THROW
(
"Variable type must be LoDTensor/SelectedRows."
);
PADDLE_THROW
(
"Variable type must be LoDTensor/SelectedRows."
);
}
}
}
}
}
else
{
}
else
{
PADDLE_THROW
(
"Unexpected branch, output variable type is %s"
,
PADDLE_THROW
(
"Unexpected branch, output variable type is %s"
,
framework
::
ToTypeName
(
out_var
->
Type
()));
framework
::
ToTypeName
(
out_var
->
Type
()));
}
}
...
@@ -82,3 +76,4 @@ public:
...
@@ -82,3 +76,4 @@ public:
};
};
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
core/privc3/boolean_tensor_test.cc
浏览文件 @
a1a9bf6b
此差异已折叠。
点击以展开。
core/privc3/fixedpoint_tensor_test.cc
浏览文件 @
a1a9bf6b
此差异已折叠。
点击以展开。
core/privc3/paddle_tensor_test.cc
浏览文件 @
a1a9bf6b
...
@@ -18,8 +18,8 @@
...
@@ -18,8 +18,8 @@
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/framework/tensor.h"
namespace
aby3
{
namespace
aby3
{
...
@@ -28,10 +28,13 @@ using paddle::framework::Tensor;
...
@@ -28,10 +28,13 @@ using paddle::framework::Tensor;
class
PaddleTensorTest
:
public
::
testing
::
Test
{
class
PaddleTensorTest
:
public
::
testing
::
Test
{
public:
public:
std
::
shared_ptr
<
TensorAdapterFactory
>
_tensor_factory
;
std
::
shared_ptr
<
TensorAdapterFactory
>
_tensor_factory
;
CPUDeviceContext
_cpu_ctx
;
CPUDeviceContext
_cpu_ctx
;
virtual
~
PaddleTensorTest
()
noexcept
{}
void
SetUp
()
{
void
SetUp
()
{
_tensor_factory
=
std
::
make_shared
<
PaddleTensorFactory
>
(
&
_cpu_ctx
);
_tensor_factory
=
std
::
make_shared
<
PaddleTensorFactory
>
(
&
_cpu_ctx
);
}
}
...
@@ -39,21 +42,20 @@ public:
...
@@ -39,21 +42,20 @@ public:
TEST_F
(
PaddleTensorTest
,
factory_test
)
{
TEST_F
(
PaddleTensorTest
,
factory_test
)
{
EXPECT_NO_THROW
(
_tensor_factory
->
template
create
<
int64_t
>());
EXPECT_NO_THROW
(
_tensor_factory
->
template
create
<
int64_t
>());
std
::
vector
<
size_t
>
shape
=
{
2
,
3
};
std
::
vector
<
size_t
>
shape
=
{
2
,
3
};
EXPECT_NO_THROW
(
_tensor_factory
->
template
create
<
int64_t
>(
shape
));
EXPECT_NO_THROW
(
_tensor_factory
->
template
create
<
int64_t
>(
shape
));
}
}
TEST_F
(
PaddleTensorTest
,
ctor_test
)
{
TEST_F
(
PaddleTensorTest
,
ctor_test
)
{
Tensor
t
;
Tensor
t
;
// t holds no memory
// t holds no memory
EXPECT_THROW
({
PaddleTensor
<
int64_t
>
pt
(
&
_cpu_ctx
,
t
);
},
EXPECT_THROW
({
PaddleTensor
<
int64_t
>
pt
(
&
_cpu_ctx
,
t
);
},
::
paddle
::
platform
::
EnforceNotMet
);
::
paddle
::
platform
::
EnforceNotMet
);
t
.
template
mutable_data
<
int64_t
>(
_cpu_ctx
.
GetPlace
());
t
.
template
mutable_data
<
int64_t
>(
_cpu_ctx
.
GetPlace
());
EXPECT_NO_THROW
({
PaddleTensor
<
int64_t
>
pt
(
&
_cpu_ctx
,
t
);
});
EXPECT_NO_THROW
({
PaddleTensor
<
int64_t
>
pt
(
&
_cpu_ctx
,
t
);
});
}
}
TEST_F
(
PaddleTensorTest
,
shape_test
)
{
TEST_F
(
PaddleTensorTest
,
shape_test
)
{
std
::
vector
<
size_t
>
shape
=
{
2
,
3
};
std
::
vector
<
size_t
>
shape
=
{
2
,
3
};
auto
pt
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
EXPECT_EQ
(
shape
.
size
(),
pt
->
shape
().
size
());
EXPECT_EQ
(
shape
.
size
(),
pt
->
shape
().
size
());
...
@@ -65,7 +67,7 @@ TEST_F(PaddleTensorTest, shape_test) {
...
@@ -65,7 +67,7 @@ TEST_F(PaddleTensorTest, shape_test) {
}
}
TEST_F
(
PaddleTensorTest
,
reshape_test
)
{
TEST_F
(
PaddleTensorTest
,
reshape_test
)
{
std
::
vector
<
size_t
>
shape
=
{
2
,
3
};
std
::
vector
<
size_t
>
shape
=
{
2
,
3
};
auto
pt
=
_tensor_factory
->
template
create
<
int64_t
>();
auto
pt
=
_tensor_factory
->
template
create
<
int64_t
>();
pt
->
reshape
(
shape
);
pt
->
reshape
(
shape
);
...
@@ -77,7 +79,7 @@ TEST_F(PaddleTensorTest, reshape_test) {
...
@@ -77,7 +79,7 @@ TEST_F(PaddleTensorTest, reshape_test) {
}
}
TEST_F
(
PaddleTensorTest
,
add_test
)
{
TEST_F
(
PaddleTensorTest
,
add_test
)
{
std
::
vector
<
size_t
>
shape
=
{
1
};
std
::
vector
<
size_t
>
shape
=
{
1
};
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt2
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt2
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
...
@@ -89,7 +91,7 @@ TEST_F(PaddleTensorTest, add_test) {
...
@@ -89,7 +91,7 @@ TEST_F(PaddleTensorTest, add_test) {
}
}
TEST_F
(
PaddleTensorTest
,
sub_test
)
{
TEST_F
(
PaddleTensorTest
,
sub_test
)
{
std
::
vector
<
size_t
>
shape
=
{
1
};
std
::
vector
<
size_t
>
shape
=
{
1
};
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt2
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt2
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
...
@@ -101,7 +103,7 @@ TEST_F(PaddleTensorTest, sub_test) {
...
@@ -101,7 +103,7 @@ TEST_F(PaddleTensorTest, sub_test) {
}
}
TEST_F
(
PaddleTensorTest
,
negative_test
)
{
TEST_F
(
PaddleTensorTest
,
negative_test
)
{
std
::
vector
<
size_t
>
shape
=
{
1
};
std
::
vector
<
size_t
>
shape
=
{
1
};
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
pt0
->
data
()[
0
]
=
2
;
pt0
->
data
()[
0
]
=
2
;
...
@@ -111,7 +113,7 @@ TEST_F(PaddleTensorTest, negative_test) {
...
@@ -111,7 +113,7 @@ TEST_F(PaddleTensorTest, negative_test) {
}
}
TEST_F
(
PaddleTensorTest
,
mul_test
)
{
TEST_F
(
PaddleTensorTest
,
mul_test
)
{
std
::
vector
<
size_t
>
shape
=
{
1
};
std
::
vector
<
size_t
>
shape
=
{
1
};
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt2
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt2
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
...
@@ -123,7 +125,7 @@ TEST_F(PaddleTensorTest, mul_test) {
...
@@ -123,7 +125,7 @@ TEST_F(PaddleTensorTest, mul_test) {
}
}
TEST_F
(
PaddleTensorTest
,
div_test
)
{
TEST_F
(
PaddleTensorTest
,
div_test
)
{
std
::
vector
<
size_t
>
shape
=
{
1
};
std
::
vector
<
size_t
>
shape
=
{
1
};
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt2
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt2
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
...
@@ -135,9 +137,9 @@ TEST_F(PaddleTensorTest, div_test) {
...
@@ -135,9 +137,9 @@ TEST_F(PaddleTensorTest, div_test) {
}
}
TEST_F
(
PaddleTensorTest
,
matmul_test
)
{
TEST_F
(
PaddleTensorTest
,
matmul_test
)
{
std
::
vector
<
size_t
>
shape0
=
{
2
,
3
};
std
::
vector
<
size_t
>
shape0
=
{
2
,
3
};
std
::
vector
<
size_t
>
shape1
=
{
3
,
2
};
std
::
vector
<
size_t
>
shape1
=
{
3
,
2
};
std
::
vector
<
size_t
>
shape2
=
{
2
,
2
};
std
::
vector
<
size_t
>
shape2
=
{
2
,
2
};
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape0
);
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape0
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape1
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape1
);
auto
pt2
=
_tensor_factory
->
template
create
<
int64_t
>(
shape2
);
auto
pt2
=
_tensor_factory
->
template
create
<
int64_t
>(
shape2
);
...
@@ -151,7 +153,7 @@ TEST_F(PaddleTensorTest, matmul_test) {
...
@@ -151,7 +153,7 @@ TEST_F(PaddleTensorTest, matmul_test) {
// | 3 4 5 | x | 2 3 | = | 28 40 |
// | 3 4 5 | x | 2 3 | = | 28 40 |
// | 4 5 |
// | 4 5 |
std
::
vector
<
int64_t
>
res
=
{
10
,
13
,
28
,
40
};
std
::
vector
<
int64_t
>
res
=
{
10
,
13
,
28
,
40
};
bool
eq
=
std
::
equal
(
res
.
begin
(),
res
.
end
(),
pt2
->
data
());
bool
eq
=
std
::
equal
(
res
.
begin
(),
res
.
end
(),
pt2
->
data
());
...
@@ -159,7 +161,7 @@ TEST_F(PaddleTensorTest, matmul_test) {
...
@@ -159,7 +161,7 @@ TEST_F(PaddleTensorTest, matmul_test) {
}
}
TEST_F
(
PaddleTensorTest
,
xor_test
)
{
TEST_F
(
PaddleTensorTest
,
xor_test
)
{
std
::
vector
<
size_t
>
shape
=
{
1
};
std
::
vector
<
size_t
>
shape
=
{
1
};
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt2
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt2
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
...
@@ -171,7 +173,7 @@ TEST_F(PaddleTensorTest, xor_test) {
...
@@ -171,7 +173,7 @@ TEST_F(PaddleTensorTest, xor_test) {
}
}
TEST_F
(
PaddleTensorTest
,
and_test
)
{
TEST_F
(
PaddleTensorTest
,
and_test
)
{
std
::
vector
<
size_t
>
shape
=
{
1
};
std
::
vector
<
size_t
>
shape
=
{
1
};
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt2
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt2
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
...
@@ -183,7 +185,7 @@ TEST_F(PaddleTensorTest, and_test) {
...
@@ -183,7 +185,7 @@ TEST_F(PaddleTensorTest, and_test) {
}
}
TEST_F
(
PaddleTensorTest
,
or_test
)
{
TEST_F
(
PaddleTensorTest
,
or_test
)
{
std
::
vector
<
size_t
>
shape
=
{
1
};
std
::
vector
<
size_t
>
shape
=
{
1
};
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt2
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt2
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
...
@@ -195,7 +197,7 @@ TEST_F(PaddleTensorTest, or_test) {
...
@@ -195,7 +197,7 @@ TEST_F(PaddleTensorTest, or_test) {
}
}
TEST_F
(
PaddleTensorTest
,
not_test
)
{
TEST_F
(
PaddleTensorTest
,
not_test
)
{
std
::
vector
<
size_t
>
shape
=
{
1
};
std
::
vector
<
size_t
>
shape
=
{
1
};
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
pt0
->
data
()[
0
]
=
0
;
pt0
->
data
()[
0
]
=
0
;
...
@@ -205,7 +207,7 @@ TEST_F(PaddleTensorTest, not_test) {
...
@@ -205,7 +207,7 @@ TEST_F(PaddleTensorTest, not_test) {
}
}
TEST_F
(
PaddleTensorTest
,
lshift_test
)
{
TEST_F
(
PaddleTensorTest
,
lshift_test
)
{
std
::
vector
<
size_t
>
shape
=
{
1
};
std
::
vector
<
size_t
>
shape
=
{
1
};
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
pt0
->
data
()[
0
]
=
2
;
pt0
->
data
()[
0
]
=
2
;
...
@@ -215,7 +217,7 @@ TEST_F(PaddleTensorTest, lshift_test) {
...
@@ -215,7 +217,7 @@ TEST_F(PaddleTensorTest, lshift_test) {
}
}
TEST_F
(
PaddleTensorTest
,
rshift_test
)
{
TEST_F
(
PaddleTensorTest
,
rshift_test
)
{
std
::
vector
<
size_t
>
shape
=
{
1
};
std
::
vector
<
size_t
>
shape
=
{
1
};
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
pt0
->
data
()[
0
]
=
2
;
pt0
->
data
()[
0
]
=
2
;
...
@@ -225,7 +227,7 @@ TEST_F(PaddleTensorTest, rshift_test) {
...
@@ -225,7 +227,7 @@ TEST_F(PaddleTensorTest, rshift_test) {
}
}
TEST_F
(
PaddleTensorTest
,
logical_rshift_test
)
{
TEST_F
(
PaddleTensorTest
,
logical_rshift_test
)
{
std
::
vector
<
size_t
>
shape
=
{
1
};
std
::
vector
<
size_t
>
shape
=
{
1
};
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt0
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt1
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
pt0
->
data
()[
0
]
=
-
1
;
pt0
->
data
()[
0
]
=
-
1
;
...
@@ -234,16 +236,17 @@ TEST_F(PaddleTensorTest, logical_rshift_test) {
...
@@ -234,16 +236,17 @@ TEST_F(PaddleTensorTest, logical_rshift_test) {
EXPECT_EQ
(
-
1ull
>>
1
,
pt1
->
data
()[
0
]);
EXPECT_EQ
(
-
1ull
>>
1
,
pt1
->
data
()[
0
]);
}
}
TEST_F
(
PaddleTensorTest
,
scale_test
)
{
TEST_F
(
PaddleTensorTest
,
scale_test
)
{
auto
pt
=
_tensor_factory
->
template
create
<
int64_t
>();
auto
pt
=
_tensor_factory
->
template
create
<
int64_t
>();
auto
pt_
=
dynamic_cast
<
PaddleTensor
<
int64_t
>
*>
(
pt
.
get
());
auto
pt_
=
dynamic_cast
<
PaddleTensor
<
int64_t
>
*>
(
pt
.
get
());
pt_
->
scaling_factor
()
=
1
;
pt_
->
scaling_factor
()
=
1
;
Tensor
t
;
Tensor
t
;
int
dim
[
1
]
=
{
1
};
int
dim
[
1
]
=
{
1
};
paddle
::
framework
::
DDim
ddim
(
dim
,
1
);
paddle
::
framework
::
DDim
ddim
(
dim
,
1
);
t
.
template
mutable_data
<
float
>(
ddim
,
_cpu_ctx
.
GetPlace
());
t
.
template
mutable_data
<
float
>(
ddim
,
_cpu_ctx
.
GetPlace
());
...
@@ -258,11 +261,11 @@ TEST_F(PaddleTensorTest, scale_test) {
...
@@ -258,11 +261,11 @@ TEST_F(PaddleTensorTest, scale_test) {
TEST_F
(
PaddleTensorTest
,
scalar_test
)
{
TEST_F
(
PaddleTensorTest
,
scalar_test
)
{
auto
pt
=
_tensor_factory
->
template
create
<
int64_t
>();
auto
pt
=
_tensor_factory
->
template
create
<
int64_t
>();
auto
pt_
=
dynamic_cast
<
PaddleTensor
<
int64_t
>
*>
(
pt
.
get
());
auto
pt_
=
dynamic_cast
<
PaddleTensor
<
int64_t
>
*>
(
pt
.
get
());
pt_
->
scaling_factor
()
=
1
;
pt_
->
scaling_factor
()
=
1
;
std
::
vector
<
size_t
>
shape
=
{
2
};
std
::
vector
<
size_t
>
shape
=
{
2
};
pt_
->
template
from_float_point_scalar
(
0.25
f
,
shape
,
2
);
pt_
->
template
from_float_point_scalar
(
0.25
f
,
shape
,
2
);
EXPECT_EQ
(
2
,
pt_
->
scaling_factor
());
EXPECT_EQ
(
2
,
pt_
->
scaling_factor
());
...
@@ -271,11 +274,11 @@ TEST_F(PaddleTensorTest, scalar_test) {
...
@@ -271,11 +274,11 @@ TEST_F(PaddleTensorTest, scalar_test) {
}
}
TEST_F
(
PaddleTensorTest
,
slice_test
)
{
TEST_F
(
PaddleTensorTest
,
slice_test
)
{
std
::
vector
<
size_t
>
shape
=
{
2
,
2
};
std
::
vector
<
size_t
>
shape
=
{
2
,
2
};
auto
pt
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
pt
=
_tensor_factory
->
template
create
<
int64_t
>(
shape
);
auto
ret
=
_tensor_factory
->
template
create
<
int64_t
>();
auto
ret
=
_tensor_factory
->
template
create
<
int64_t
>();
auto
pt_
=
dynamic_cast
<
PaddleTensor
<
int64_t
>
*>
(
pt
.
get
());
auto
pt_
=
dynamic_cast
<
PaddleTensor
<
int64_t
>
*>
(
pt
.
get
());
pt_
->
scaling_factor
()
=
1
;
pt_
->
scaling_factor
()
=
1
;
for
(
size_t
i
=
0
;
i
<
4
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
4
;
++
i
)
{
...
...
python/paddle_fl/mpc/framework.py
浏览文件 @
a1a9bf6b
...
@@ -21,14 +21,13 @@ from paddle.fluid import core
...
@@ -21,14 +21,13 @@ from paddle.fluid import core
from
paddle.fluid
import
unique_name
from
paddle.fluid
import
unique_name
from
paddle.fluid.framework
import
Variable
from
paddle.fluid.framework
import
Variable
from
paddle.fluid.framework
import
convert_np_dtype_to_dtype_
from
paddle.fluid.framework
import
convert_np_dtype_to_dtype_
from
paddle.fluid.data_feeder
import
check_type
,
check_dtype
class
MpcVariable
(
Variable
):
class
MpcVariable
(
Variable
):
"""
"""
Extends from paddle.fluid.framework.Variable and rewrite
Extends from paddle.fluid.framework.Variable and rewrite
the __init__ method where the shape is resized.
the __init__ method where the shape is resized.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
block
,
block
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
...
@@ -91,22 +90,22 @@ class MpcVariable(Variable):
...
@@ -91,22 +90,22 @@ class MpcVariable(Variable):
else
:
else
:
old_dtype
=
self
.
dtype
old_dtype
=
self
.
dtype
if
dtype
!=
old_dtype
:
if
dtype
!=
old_dtype
:
raise
ValueError
(
raise
ValueError
(
"MpcVariable {0} has been created before. "
"MpcVariable {0} has been created before. "
"The previous data type is {1}; the new "
"The previous data type is {1}; the new "
"data type is {2}. They are not "
"data type is {2}. They are not "
"matched."
.
format
(
self
.
name
,
old_dtype
,
dtype
))
"matched."
.
format
(
self
.
name
,
old_dtype
,
dtype
))
if
lod_level
is
not
None
:
if
lod_level
is
not
None
:
if
is_new_var
:
if
is_new_var
:
self
.
desc
.
set_lod_level
(
lod_level
)
self
.
desc
.
set_lod_level
(
lod_level
)
else
:
else
:
if
lod_level
!=
self
.
lod_level
:
if
lod_level
!=
self
.
lod_level
:
raise
ValueError
(
raise
ValueError
(
"MpcVariable {0} has been created before. "
"MpcVariable {0} has been created before. "
"The previous lod_level is {1}; the new "
"The previous lod_level is {1}; the new "
"lod_level is {2}. They are not "
"lod_level is {2}. They are not "
"matched"
.
format
(
self
.
name
,
self
.
lod_level
,
lod_level
))
"matched"
.
format
(
self
.
name
,
self
.
lod_level
,
lod_level
))
if
persistable
is
not
None
:
if
persistable
is
not
None
:
if
is_new_var
:
if
is_new_var
:
self
.
desc
.
set_persistable
(
persistable
)
self
.
desc
.
set_persistable
(
persistable
)
...
@@ -156,8 +155,7 @@ class MpcParameter(MpcVariable):
...
@@ -156,8 +155,7 @@ class MpcParameter(MpcVariable):
if
len
(
shape
)
==
0
:
if
len
(
shape
)
==
0
:
raise
ValueError
(
raise
ValueError
(
"The dimensions of shape for MpcParameter must be greater than 0"
"The dimensions of shape for MpcParameter must be greater than 0"
)
)
for
each
in
shape
:
for
each
in
shape
:
if
each
<
0
:
if
each
<
0
:
...
@@ -175,8 +173,7 @@ class MpcParameter(MpcVariable):
...
@@ -175,8 +173,7 @@ class MpcParameter(MpcVariable):
**
kwargs
)
**
kwargs
)
self
.
trainable
=
kwargs
.
get
(
'trainable'
,
True
)
self
.
trainable
=
kwargs
.
get
(
'trainable'
,
True
)
self
.
optimize_attr
=
kwargs
.
get
(
'optimize_attr'
,
self
.
optimize_attr
=
kwargs
.
get
(
'optimize_attr'
,
{
'learning_rate'
:
1.0
})
{
'learning_rate'
:
1.0
})
self
.
regularizer
=
kwargs
.
get
(
'regularizer'
,
None
)
self
.
regularizer
=
kwargs
.
get
(
'regularizer'
,
None
)
...
@@ -203,8 +200,8 @@ class MpcParameter(MpcVariable):
...
@@ -203,8 +200,8 @@ class MpcParameter(MpcVariable):
additional_attr
=
(
"trainable"
,
"optimize_attr"
,
"regularizer"
,
additional_attr
=
(
"trainable"
,
"optimize_attr"
,
"regularizer"
,
"gradient_clip_attr"
,
"do_model_average"
)
"gradient_clip_attr"
,
"do_model_average"
)
for
attr_name
in
additional_attr
:
for
attr_name
in
additional_attr
:
res_str
+=
"%s: %s
\n
"
%
(
res_str
+=
"%s: %s
\n
"
%
(
attr_name
,
attr_name
,
cpt
.
to_text
(
getattr
(
self
,
attr_name
)))
cpt
.
to_text
(
getattr
(
self
,
attr_name
)))
else
:
else
:
res_str
=
MpcVariable
.
to_string
(
self
,
throw_on_error
,
False
)
res_str
=
MpcVariable
.
to_string
(
self
,
throw_on_error
,
False
)
return
res_str
return
res_str
...
@@ -245,8 +242,7 @@ def create_mpc_parameter(block, *args, **kwargs):
...
@@ -245,8 +242,7 @@ def create_mpc_parameter(block, *args, **kwargs):
init_ops_len
=
len
(
init_ops
)
init_ops_len
=
len
(
init_ops
)
if
init_ops_len
>
1
:
if
init_ops_len
>
1
:
raise
RuntimeError
(
"mpc_param "
+
mpc_param
.
name
+
raise
RuntimeError
(
"mpc_param "
+
mpc_param
.
name
+
" is inited by multiple init ops "
+
str
(
" is inited by multiple init ops "
+
str
(
init_ops
))
init_ops
))
elif
init_ops_len
==
1
:
elif
init_ops_len
==
1
:
# TODO(Paddle 1.7): already inited, do nothing, should log a warning
# TODO(Paddle 1.7): already inited, do nothing, should log a warning
pass
pass
...
@@ -272,7 +268,6 @@ def create_mpc_var(block, *args, **kwargs):
...
@@ -272,7 +268,6 @@ def create_mpc_var(block, *args, **kwargs):
kwargs
[
'initializer'
](
var
,
block
)
kwargs
[
'initializer'
](
var
,
block
)
return
var
return
var
def
is_mpc_parameter
(
var
):
def
is_mpc_parameter
(
var
):
"""
"""
Check whether the given variable is an instance of MpcParameter.
Check whether the given variable is an instance of MpcParameter.
...
@@ -282,4 +277,13 @@ def is_mpc_parameter(var):
...
@@ -282,4 +277,13 @@ def is_mpc_parameter(var):
bool: True if the given `var` is an instance of Parameter,
bool: True if the given `var` is an instance of Parameter,
False if not.
False if not.
"""
"""
return
isinstance
(
var
,
MpcParameter
)
return
type
(
var
)
==
MpcParameter
def
check_mpc_variable_and_dtype
(
input
,
input_name
,
expected_dtype
,
op_name
,
extra_message
=
''
):
check_type
(
input
,
input_name
,
MpcVariable
,
op_name
,
extra_message
)
check_dtype
(
input
.
dtype
,
input_name
,
expected_dtype
,
op_name
,
extra_message
)
python/paddle_fl/mpc/layers/__init__.py
浏览文件 @
a1a9bf6b
python/paddle_fl/mpc/layers/basic.py
浏览文件 @
a1a9bf6b
...
@@ -14,9 +14,10 @@
...
@@ -14,9 +14,10 @@
"""
"""
basic mpc op layers.
basic mpc op layers.
"""
"""
from
paddle.fluid.data_feeder
import
check_
typ
e_and_dtype
from
paddle.fluid.data_feeder
import
check_
variabl
e_and_dtype
from
..framework
import
MpcVariable
from
..framework
import
MpcVariable
from
..framework
import
check_mpc_variable_and_dtype
from
..mpc_layer_helper
import
MpcLayerHelper
from
..mpc_layer_helper
import
MpcLayerHelper
__all__
=
[
__all__
=
[
...
@@ -32,8 +33,8 @@ def _elementwise_op(helper):
...
@@ -32,8 +33,8 @@ def _elementwise_op(helper):
assert
x
is
not
None
,
'x cannot be None in {}'
.
format
(
op_type
)
assert
x
is
not
None
,
'x cannot be None in {}'
.
format
(
op_type
)
assert
y
is
not
None
,
'y cannot be None in {}'
.
format
(
op_type
)
assert
y
is
not
None
,
'y cannot be None in {}'
.
format
(
op_type
)
check_
type_and_dtype
(
x
,
'x'
,
MpcVariable
,
[
'int64'
],
op_type
)
check_
mpc_variable_and_dtype
(
x
,
'x'
,
[
'int64'
],
op_type
)
check_
type_and_dtype
(
y
,
'y'
,
MpcVariable
,
[
'int64'
],
op_type
)
check_
mpc_variable_and_dtype
(
y
,
'y'
,
[
'int64'
],
op_type
)
axis
=
helper
.
kwargs
.
get
(
'axis'
,
-
1
)
axis
=
helper
.
kwargs
.
get
(
'axis'
,
-
1
)
use_mkldnn
=
helper
.
kwargs
.
get
(
'use_mkldnn'
,
False
)
use_mkldnn
=
helper
.
kwargs
.
get
(
'use_mkldnn'
,
False
)
...
...
python/paddle_fl/mpc/layers/compare.py
浏览文件 @
a1a9bf6b
# Copyright (c) 20
18
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 20
20
PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
"""
"""
mpc math compare layers.
mpc math compare layers.
"""
"""
from
paddle.fluid.data_feeder
import
check_type_and_dtype
from
..framework
import
MpcVariable
from
..framework
import
MpcVariable
from
..mpc_layer_helper
import
MpcLayerHelper
from
..mpc_layer_helper
import
MpcLayerHelper
...
...
python/paddle_fl/mpc/layers/math.py
浏览文件 @
a1a9bf6b
# Copyright (c) 20
18
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 20
20
PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -14,9 +14,9 @@
...
@@ -14,9 +14,9 @@
"""
"""
mpc math op layers.
mpc math op layers.
"""
"""
from
paddle.fluid.data_feeder
import
check_type_and_dtype
from
..framework
import
MpcVariable
from
..framework
import
MpcVariable
from
..framework
import
check_mpc_variable_and_dtype
from
..mpc_layer_helper
import
MpcLayerHelper
from
..mpc_layer_helper
import
MpcLayerHelper
__all__
=
[
__all__
=
[
...
@@ -39,7 +39,7 @@ def mean(x, name=None):
...
@@ -39,7 +39,7 @@ def mean(x, name=None):
Examples: todo
Examples: todo
"""
"""
helper
=
MpcLayerHelper
(
"mean"
,
**
locals
())
helper
=
MpcLayerHelper
(
"mean"
,
**
locals
())
check_
type_and_dtype
(
x
,
'x'
,
MpcVariable
,
[
'int64'
],
'mean'
)
check_
mpc_variable_and_dtype
(
x
,
'x'
,
[
'int64'
],
'mean'
)
if
name
is
None
:
if
name
is
None
:
out
=
helper
.
create_mpc_variable_for_type_inference
(
dtype
=
x
.
dtype
)
out
=
helper
.
create_mpc_variable_for_type_inference
(
dtype
=
x
.
dtype
)
else
:
else
:
...
@@ -64,7 +64,7 @@ def square(x, name=None):
...
@@ -64,7 +64,7 @@ def square(x, name=None):
Examples: todo
Examples: todo
"""
"""
helper
=
MpcLayerHelper
(
"square"
,
**
locals
())
helper
=
MpcLayerHelper
(
"square"
,
**
locals
())
check_
type_and_dtype
(
x
,
'x'
,
MpcVariable
,
[
'int64'
],
'square'
)
check_
mpc_variable_and_dtype
(
x
,
'x'
,
[
'int64'
],
'square'
)
if
name
is
None
:
if
name
is
None
:
out
=
helper
.
create_mpc_variable_for_type_inference
(
dtype
=
x
.
dtype
)
out
=
helper
.
create_mpc_variable_for_type_inference
(
dtype
=
x
.
dtype
)
else
:
else
:
...
@@ -89,8 +89,7 @@ def sum(x):
...
@@ -89,8 +89,7 @@ def sum(x):
Examples: todo
Examples: todo
"""
"""
helper
=
MpcLayerHelper
(
"sum"
,
**
locals
())
helper
=
MpcLayerHelper
(
"sum"
,
**
locals
())
out
=
helper
.
create_mpc_variable_for_type_inference
(
out
=
helper
.
create_mpc_variable_for_type_inference
(
dtype
=
helper
.
input_dtype
(
'x'
))
dtype
=
helper
.
input_dtype
(
'x'
))
helper
.
append_op
(
helper
.
append_op
(
type
=
"mpc_sum"
,
type
=
"mpc_sum"
,
inputs
=
{
"X"
:
x
},
inputs
=
{
"X"
:
x
},
...
@@ -116,16 +115,14 @@ def square_error_cost(input, label):
...
@@ -116,16 +115,14 @@ def square_error_cost(input, label):
Examples: todo
Examples: todo
"""
"""
helper
=
MpcLayerHelper
(
'square_error_cost'
,
**
locals
())
helper
=
MpcLayerHelper
(
'square_error_cost'
,
**
locals
())
minus_out
=
helper
.
create_mpc_variable_for_type_inference
(
minus_out
=
helper
.
create_mpc_variable_for_type_inference
(
dtype
=
input
.
dtype
)
dtype
=
input
.
dtype
)
helper
.
append_op
(
helper
.
append_op
(
type
=
'mpc_elementwise_sub'
,
type
=
'mpc_elementwise_sub'
,
inputs
=
{
'X'
:
[
input
],
inputs
=
{
'X'
:
[
input
],
'Y'
:
[
label
]},
'Y'
:
[
label
]},
outputs
=
{
'Out'
:
[
minus_out
]})
outputs
=
{
'Out'
:
[
minus_out
]})
square_out
=
helper
.
create_mpc_variable_for_type_inference
(
square_out
=
helper
.
create_mpc_variable_for_type_inference
(
dtype
=
input
.
dtype
)
dtype
=
input
.
dtype
)
helper
.
append_op
(
helper
.
append_op
(
type
=
'mpc_square'
,
type
=
'mpc_square'
,
inputs
=
{
'X'
:
[
minus_out
]},
inputs
=
{
'X'
:
[
minus_out
]},
...
...
python/paddle_fl/mpc/layers/matrix.py
浏览文件 @
a1a9bf6b
# Copyright (c) 20
18
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 20
20
PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -14,12 +14,14 @@
...
@@ -14,12 +14,14 @@
"""
"""
mpc matrix op layers.
mpc matrix op layers.
"""
"""
from
paddle.fluid.data_feeder
import
check_type_and_dtype
from
..framework
import
MpcVariable
from
..framework
import
MpcVariable
from
..framework
import
check_mpc_variable_and_dtype
from
..mpc_layer_helper
import
MpcLayerHelper
from
..mpc_layer_helper
import
MpcLayerHelper
__all__
=
[
'mul'
,
]
__all__
=
[
'mul'
,
]
def
mul
(
x
,
y
,
x_num_col_dims
=
1
,
y_num_col_dims
=
1
,
name
=
None
):
def
mul
(
x
,
y
,
x_num_col_dims
=
1
,
y_num_col_dims
=
1
,
name
=
None
):
...
@@ -66,8 +68,8 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
...
@@ -66,8 +68,8 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
}
}
helper
=
MpcLayerHelper
(
"mul"
,
**
locals
())
helper
=
MpcLayerHelper
(
"mul"
,
**
locals
())
check_
type_and_dtype
(
x
,
'x'
,
MpcVariable
,
[
'int64'
],
'mul'
)
check_
mpc_variable_and_dtype
(
x
,
'x'
,
[
'int64'
],
'mul'
)
check_
type_and_dtype
(
y
,
'y'
,
MpcVariable
,
[
'int64'
],
'mul'
)
check_
mpc_variable_and_dtype
(
y
,
'y'
,
[
'int64'
],
'mul'
)
if
name
is
None
:
if
name
is
None
:
out
=
helper
.
create_mpc_variable_for_type_inference
(
dtype
=
x
.
dtype
)
out
=
helper
.
create_mpc_variable_for_type_inference
(
dtype
=
x
.
dtype
)
else
:
else
:
...
...
python/paddle_fl/mpc/layers/ml.py
浏览文件 @
a1a9bf6b
# Copyright (c) 20
18
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 20
20
PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -17,9 +17,9 @@ mpc ml op layers.
...
@@ -17,9 +17,9 @@ mpc ml op layers.
from
functools
import
reduce
from
functools
import
reduce
from
paddle.fluid.data_feeder
import
check_type
,
check_dtype
from
paddle.fluid.data_feeder
import
check_type
,
check_dtype
from
paddle.fluid.data_feeder
import
check_type_and_dtype
import
numpy
import
numpy
from
..framework
import
MpcVariable
from
..framework
import
MpcVariable
from
..framework
import
check_mpc_variable_and_dtype
from
..mpc_layer_helper
import
MpcLayerHelper
from
..mpc_layer_helper
import
MpcLayerHelper
__all__
=
[
__all__
=
[
...
@@ -30,9 +30,6 @@ __all__ = [
...
@@ -30,9 +30,6 @@ __all__ = [
]
]
# add softmax, relu
def
fc
(
input
,
def
fc
(
input
,
size
,
size
,
num_flatten_dims
=
1
,
num_flatten_dims
=
1
,
...
@@ -186,8 +183,7 @@ def softmax(input, use_cudnn=False, name=None, axis=-1):
...
@@ -186,8 +183,7 @@ def softmax(input, use_cudnn=False, name=None, axis=-1):
"""
"""
attrs
=
{
"axis"
:
axis
,
"use_cudnn"
:
use_cudnn
}
attrs
=
{
"axis"
:
axis
,
"use_cudnn"
:
use_cudnn
}
helper
=
MpcLayerHelper
(
'softmax'
,
**
locals
())
helper
=
MpcLayerHelper
(
'softmax'
,
**
locals
())
check_type_and_dtype
(
input
,
'input'
,
MpcVariable
,
check_mpc_variable_and_dtype
(
input
,
'input'
,
[
'int64'
],
'softmax'
)
[
'float16'
,
'float32'
,
'float64'
],
'softmax'
)
dtype
=
helper
.
input_dtype
()
dtype
=
helper
.
input_dtype
()
mpc_softmax_out
=
helper
.
create_mpc_variable_for_type_inference
(
dtype
)
mpc_softmax_out
=
helper
.
create_mpc_variable_for_type_inference
(
dtype
)
...
@@ -226,7 +222,9 @@ def relu(input, name=None):
...
@@ -226,7 +222,9 @@ def relu(input, name=None):
dtype
=
helper
.
input_dtype
(
input_param_name
=
'input'
)
dtype
=
helper
.
input_dtype
(
input_param_name
=
'input'
)
out
=
helper
.
create_mpc_variable_for_type_inference
(
dtype
)
out
=
helper
.
create_mpc_variable_for_type_inference
(
dtype
)
helper
.
append_op
(
helper
.
append_op
(
type
=
"mpc_relu"
,
inputs
=
{
"X"
:
input
},
outputs
=
{
"Y"
:
out
})
type
=
"mpc_relu"
,
inputs
=
{
"X"
:
input
},
outputs
=
{
"Y"
:
out
})
return
out
return
out
...
...
python/paddle_fl/mpc/layers/mpc_math_op_patch.py
浏览文件 @
a1a9bf6b
...
@@ -32,7 +32,6 @@ def monkey_patch_mpc_variable():
...
@@ -32,7 +32,6 @@ def monkey_patch_mpc_variable():
Monkey patch for operator overloading.
Monkey patch for operator overloading.
:return:
:return:
"""
"""
def
unique_tmp_name
():
def
unique_tmp_name
():
"""
"""
Generate temp name for variable.
Generate temp name for variable.
...
@@ -80,7 +79,9 @@ def monkey_patch_mpc_variable():
...
@@ -80,7 +79,9 @@ def monkey_patch_mpc_variable():
tmp_name
=
unique_tmp_name
()
tmp_name
=
unique_tmp_name
()
return
block
.
create_var
(
name
=
tmp_name
,
dtype
=
dtype
)
return
block
.
create_var
(
name
=
tmp_name
,
dtype
=
dtype
)
def
_elemwise_method_creator_
(
method_name
,
op_type
,
reverse
=
False
):
def
_elemwise_method_creator_
(
method_name
,
op_type
,
reverse
=
False
):
"""
"""
Operator overloading for different method.
Operator overloading for different method.
:param method_name: the name of operator which is overloaded.
:param method_name: the name of operator which is overloaded.
...
@@ -88,19 +89,16 @@ def monkey_patch_mpc_variable():
...
@@ -88,19 +89,16 @@ def monkey_patch_mpc_variable():
:param reverse:
:param reverse:
:return:
:return:
"""
"""
def
__impl__
(
self
,
other_var
):
def
__impl__
(
self
,
other_var
):
lhs_dtype
=
safe_get_dtype
(
self
)
lhs_dtype
=
safe_get_dtype
(
self
)
if
method_name
in
compare_ops
:
if
method_name
in
compare_ops
:
if
not
isinstance
(
other_var
,
Variable
):
if
not
isinstance
(
other_var
,
Variable
):
raise
NotImplementedError
(
raise
NotImplementedError
(
"Unsupported data type of {} for compare operations."
"Unsupported data type of {} for compare operations."
.
format
(
other_var
.
name
))
.
format
(
other_var
.
name
))
else
:
else
:
if
not
isinstance
(
other_var
,
MpcVariable
):
if
not
isinstance
(
other_var
,
MpcVariable
):
raise
NotImplementedError
(
raise
NotImplementedError
(
"Unsupported data type of {}."
.
format
(
other_var
.
name
))
"Unsupported data type of {}."
.
format
(
other_var
.
name
))
rhs_dtype
=
safe_get_dtype
(
other_var
)
rhs_dtype
=
safe_get_dtype
(
other_var
)
if
reverse
:
if
reverse
:
...
@@ -111,8 +109,7 @@ def monkey_patch_mpc_variable():
...
@@ -111,8 +109,7 @@ def monkey_patch_mpc_variable():
if
method_name
in
compare_ops
:
if
method_name
in
compare_ops
:
out
=
create_new_tmp_var
(
current_block
(
self
),
dtype
=
rhs_dtype
)
out
=
create_new_tmp_var
(
current_block
(
self
),
dtype
=
rhs_dtype
)
else
:
else
:
out
=
create_new_tmp_mpc_var
(
out
=
create_new_tmp_mpc_var
(
current_block
(
self
),
dtype
=
lhs_dtype
)
current_block
(
self
),
dtype
=
lhs_dtype
)
# out = create_new_tmp_mpc_var(current_block(self), dtype=lhs_dtype)
# out = create_new_tmp_mpc_var(current_block(self), dtype=lhs_dtype)
...
@@ -179,10 +176,11 @@ def monkey_patch_mpc_variable():
...
@@ -179,10 +176,11 @@ def monkey_patch_mpc_variable():
(
"__lt__"
,
"mpc_less_than"
,
False
),
(
"__lt__"
,
"mpc_less_than"
,
False
),
(
"__le__"
,
"mpc_less_equal"
,
False
),
(
"__le__"
,
"mpc_less_equal"
,
False
),
(
"__gt__"
,
"mpc_greater_than"
,
False
),
(
"__gt__"
,
"mpc_greater_than"
,
False
),
(
"__ge__"
,
"mpc_greater_equal"
,
False
)):
(
"__ge__"
,
"mpc_greater_equal"
,
False
)
):
# Not support computation between MpcVariable and scalar.
# Not support computation between MpcVariable and scalar.
setattr
(
MpcVariable
,
method_name
,
setattr
(
MpcVariable
,
method_name
,
_elemwise_method_creator_
(
method_name
,
op_type
,
reverse
)
_elemwise_method_creator_
(
method_name
,
op_type
,
reverse
)
if
method_name
in
supported_mpc_ops
else
announce_not_impl
)
if
method_name
in
supported_mpc_ops
else
announce_not_impl
)
# MpcVariable.astype = astype
python/setup.py
浏览文件 @
a1a9bf6b
...
@@ -34,7 +34,7 @@ def python_version():
...
@@ -34,7 +34,7 @@ def python_version():
max_version
,
mid_version
,
min_version
=
python_version
()
max_version
,
mid_version
,
min_version
=
python_version
()
REQUIRED_PACKAGES
=
[
REQUIRED_PACKAGES
=
[
'six >= 1.10.0'
,
'protobuf >= 3.1.0'
,
'paddlepaddle
== 1.6.3'
,
'six >= 1.10.0'
,
'protobuf >= 3.1.0'
,
'paddlepaddle
>= 1.8.0'
,
'paddlepaddle-gpu >= 1.8'
]
]
if
max_version
<
3
:
if
max_version
<
3
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录