Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleFL
提交
d829309d
P
PaddleFL
项目概览
PaddlePaddle
/
PaddleFL
通知
35
Star
5
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
6
列表
看板
标记
里程碑
合并请求
4
Wiki
3
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleFL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
6
Issue
6
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
3
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d829309d
编写于
8月 31, 2020
作者:
Y
yangqingyou
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
amend according comment
上级
5e248249
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
121 addition
and
102 deletion
+121
-102
core/paddlefl_mpc/mpc_protocol/abstract_context.h
core/paddlefl_mpc/mpc_protocol/abstract_context.h
+18
-67
core/privc/privc_context.h
core/privc/privc_context.h
+13
-13
core/privc3/aby3_context.h
core/privc3/aby3_context.h
+13
-13
core/privc3/boolean_tensor_impl.h
core/privc3/boolean_tensor_impl.h
+10
-9
core/privc3/ot.h
core/privc3/ot.h
+67
-0
未找到文件。
core/paddlefl_mpc/mpc_protocol/abstract_context.h
浏览文件 @
d829309d
...
...
@@ -29,24 +29,23 @@ using PseudorandomNumberGenerator = psi::PseudorandomNumberGenerator;
class
AbstractContext
{
public:
AbstractContext
()
=
default
;
AbstractContext
(
size_t
party
,
std
::
shared_ptr
<
AbstractNetwork
>
network
)
{
init
(
party
,
network
);
};
AbstractContext
(
const
AbstractContext
&
other
)
=
delete
;
AbstractContext
&
operator
=
(
const
AbstractContext
&
other
)
=
delete
;
virtual
void
init
(
size_t
party
,
std
::
shared_ptr
<
AbstractNetwork
>
network
,
block
seed
,
block
seed2
)
=
0
;
void
init
(
size_t
party
,
std
::
shared_ptr
<
AbstractNetwork
>
network
)
{
set_party
(
party
);
set_network
(
network
);
}
void
set_party
(
size_t
party
)
{
PADDLE_ENFORCE_LT
(
party
,
_num_party
,
"party idx should less than %d."
,
_num_party
);
_party
=
party
;
}
void
set_num_party
(
size_t
num_party
)
{
PADDLE_ENFORCE_EQ
(
num_party
==
2
||
num_party
==
3
,
true
,
"2 or 3 party protocol is supported."
);
_num_party
=
num_party
;
}
...
...
@@ -60,35 +59,33 @@ public:
PADDLE_ENFORCE_LE
(
idx
,
_num_party
,
"prng idx should be less and equal to %d."
,
_num_party
);
_prng
[
idx
]
.
set_seed
(
seed
);
get_prng
(
idx
)
.
set_seed
(
seed
);
}
size_t
party
()
const
{
return
_party
;
}
size_t
pre_party
()
const
{
PADDLE_ENFORCE_EQ
(
_num_party
==
2
||
_num_party
==
3
,
true
,
"number of party is not set."
);
return
(
_party
+
_num_party
-
1
)
%
_num_party
;
}
size_t
next_party
()
const
{
PADDLE_ENFORCE_EQ
(
_num_party
==
2
||
_num_party
==
3
,
true
,
"number of party is not set."
);
return
(
_party
+
1
)
%
_num_party
;
}
template
<
typename
T
>
T
gen_random
(
bool
next
)
{
return
_prng
[
next
].
get
<
T
>
();
}
// generate random from prng[0] or prng[1]
// @param next: use bool type for idx 0 or 1
template
<
typename
T
>
T
gen_random
(
bool
next
)
{
return
get_prng
(
next
).
get
<
T
>
();
}
template
<
typename
T
,
template
<
typename
>
class
Tensor
>
void
gen_random
(
Tensor
<
T
>
&
tensor
,
bool
next
)
{
PADDLE_ENFORCE_EQ
(
_num_party
,
3
,
"`gen_random` API is for 3 party protocol."
);
std
::
for_each
(
tensor
.
data
(),
tensor
.
data
()
+
tensor
.
numel
(),
[
this
,
next
](
T
&
val
)
{
val
=
this
->
template
gen_random
<
T
>(
next
);
});
}
template
<
typename
T
>
T
gen_random_private
()
{
return
_prng
[
2
]
.
get
<
T
>
();
}
template
<
typename
T
>
T
gen_random_private
()
{
return
get_prng
(
2
)
.
get
<
T
>
();
}
template
<
typename
T
,
template
<
typename
>
class
Tensor
>
void
gen_random_private
(
Tensor
<
T
>
&
tensor
)
{
...
...
@@ -98,15 +95,11 @@ public:
}
template
<
typename
T
>
T
gen_zero_sharing_arithmetic
()
{
PADDLE_ENFORCE_EQ
(
_num_party
,
3
,
"`gen_zero_sharing_arithmetic` API is for 3 party protocol."
);
return
_prng
[
0
].
get
<
T
>
()
-
_prng
[
1
].
get
<
T
>
();
return
get_prng
(
0
).
get
<
T
>
()
-
get_prng
(
1
).
get
<
T
>
();
}
template
<
typename
T
,
template
<
typename
>
class
Tensor
>
void
gen_zero_sharing_arithmetic
(
Tensor
<
T
>
&
tensor
)
{
PADDLE_ENFORCE_EQ
(
_num_party
,
3
,
"`gen_zero_sharing_arithmetic` API is for 3 party protocol."
);
std
::
for_each
(
tensor
.
data
(),
tensor
.
data
()
+
tensor
.
numel
(),
[
this
](
T
&
val
)
{
val
=
this
->
template
gen_zero_sharing_arithmetic
<
T
>();
...
...
@@ -114,60 +107,18 @@ public:
}
template
<
typename
T
>
T
gen_zero_sharing_boolean
()
{
PADDLE_ENFORCE_EQ
(
_num_party
,
3
,
"`gen_zero_sharing_boolean` API is for 3 party protocol."
);
return
_prng
[
0
].
get
<
T
>
()
^
_prng
[
1
].
get
<
T
>
();
return
get_prng
(
0
).
get
<
T
>
()
^
get_prng
(
1
).
get
<
T
>
();
}
template
<
typename
T
,
template
<
typename
>
class
Tensor
>
void
gen_zero_sharing_boolean
(
Tensor
<
T
>
&
tensor
)
{
PADDLE_ENFORCE_EQ
(
_num_party
,
3
,
"`gen_zero_sharing_boolean` API is for 3 party protocol."
);
std
::
for_each
(
tensor
.
data
(),
tensor
.
data
()
+
tensor
.
numel
(),
[
this
](
T
&
val
)
{
val
=
this
->
template
gen_zero_sharing_boolean
<
T
>();
});
}
template
<
typename
T
,
template
<
typename
>
class
Tensor
>
void
ot
(
size_t
sender
,
size_t
receiver
,
size_t
helper
,
const
Tensor
<
T
>*
choice
,
const
Tensor
<
T
>*
m
[
2
],
Tensor
<
T
>*
buffer
[
2
],
Tensor
<
T
>*
ret
)
{
// TODO: check tensor shape equals
const
size_t
numel
=
buffer
[
0
]
->
numel
();
if
(
party
()
==
sender
)
{
bool
common
=
helper
==
next_party
();
this
->
template
gen_random
(
*
buffer
[
0
],
common
);
this
->
template
gen_random
(
*
buffer
[
1
],
common
);
for
(
size_t
i
=
0
;
i
<
numel
;
++
i
)
{
buffer
[
0
]
->
data
()[
i
]
^=
m
[
0
]
->
data
()[
i
];
buffer
[
1
]
->
data
()[
i
]
^=
m
[
1
]
->
data
()[
i
];
}
network
()
->
template
send
(
receiver
,
*
buffer
[
0
]);
network
()
->
template
send
(
receiver
,
*
buffer
[
1
]);
}
else
if
(
party
()
==
helper
)
{
bool
common
=
sender
==
next_party
();
this
->
template
gen_random
(
*
buffer
[
0
],
common
);
this
->
template
gen_random
(
*
buffer
[
1
],
common
);
for
(
size_t
i
=
0
;
i
<
numel
;
++
i
)
{
buffer
[
0
]
->
data
()[
i
]
=
choice
->
data
()[
i
]
&
1
?
buffer
[
1
]
->
data
()[
i
]
:
buffer
[
0
]
->
data
()[
i
];
}
network
()
->
template
send
(
receiver
,
*
buffer
[
0
]);
}
else
if
(
party
()
==
receiver
)
{
network
()
->
template
recv
(
sender
,
*
buffer
[
0
]);
network
()
->
template
recv
(
sender
,
*
buffer
[
1
]);
network
()
->
template
recv
(
helper
,
*
ret
);
size_t
i
=
0
;
std
::
for_each
(
ret
->
data
(),
ret
->
data
()
+
numel
,
[
&
buffer
,
&
i
,
choice
,
ret
](
T
&
in
)
{
bool
c
=
choice
->
data
()[
i
]
&
1
;
in
^=
buffer
[
c
]
->
data
()[
i
];
++
i
;}
);
}
}
protected:
virtual
PseudorandomNumberGenerator
&
get_prng
(
size_t
idx
)
=
0
;
private:
size_t
_num_party
;
...
...
core/privc/privc_context.h
浏览文件 @
d829309d
...
...
@@ -29,26 +29,26 @@ using AbstractContext = paddle::mpc::AbstractContext;
class
PrivCContext
:
public
AbstractContext
{
public:
PrivCContext
(
size_t
party
,
std
::
shared_ptr
<
AbstractNetwork
>
network
,
const
block
&
seed
=
g_zero_block
)
{
init
(
party
,
network
,
g_zero_block
,
seed
);
block
seed
=
g_zero_block
)
:
AbstractContext
::
AbstractContext
(
party
,
network
)
{
set_num_party
(
2
);
if
(
psi
::
equals
(
seed
,
psi
::
g_zero_block
))
{
seed
=
psi
::
block_from_dev_urandom
();
}
set_random_seed
(
seed
,
0
);
}
PrivCContext
(
const
PrivCContext
&
other
)
=
delete
;
PrivCContext
&
operator
=
(
const
PrivCContext
&
other
)
=
delete
;
void
init
(
size_t
party
,
std
::
shared_ptr
<
AbstractNetwork
>
network
,
block
seed
,
block
seed2
)
override
{
set_num_party
(
2
);
set_party
(
party
);
set_network
(
network
);
if
(
psi
::
equals
(
seed2
,
psi
::
g_zero_block
))
{
seed2
=
psi
::
block_from_dev_urandom
();
}
// seed2 is private
set_random_seed
(
seed2
,
2
);
protected:
PseudorandomNumberGenerator
&
get_prng
(
size_t
idx
)
override
{
return
_prng
;
}
private:
PseudorandomNumberGenerator
_prng
;
};
}
// namespace aby3
core/privc3/aby3_context.h
浏览文件 @
d829309d
...
...
@@ -29,20 +29,10 @@ using AbstractContext = paddle::mpc::AbstractContext;
class
ABY3Context
:
public
AbstractContext
{
public:
ABY3Context
(
size_t
party
,
std
::
shared_ptr
<
AbstractNetwork
>
network
,
const
block
&
seed
=
g_zero_block
,
const
block
&
seed2
=
g_zero_block
)
{
init
(
party
,
network
,
seed
,
seed2
);
}
ABY3Context
(
const
ABY3Context
&
other
)
=
delete
;
ABY3Context
&
operator
=
(
const
ABY3Context
&
other
)
=
delete
;
void
init
(
size_t
party
,
std
::
shared_ptr
<
AbstractNetwork
>
network
,
block
seed
,
block
seed2
)
override
{
block
seed
=
g_zero_block
,
block
seed2
=
g_zero_block
)
:
AbstractContext
::
AbstractContext
(
party
,
network
)
{
set_num_party
(
3
);
set_party
(
party
);
set_network
(
network
);
if
(
psi
::
equals
(
seed
,
psi
::
g_zero_block
))
{
seed
=
psi
::
block_from_dev_urandom
();
...
...
@@ -70,6 +60,16 @@ public:
set_random_seed
(
seed
,
1
);
}
ABY3Context
(
const
ABY3Context
&
other
)
=
delete
;
ABY3Context
&
operator
=
(
const
ABY3Context
&
other
)
=
delete
;
protected:
PseudorandomNumberGenerator
&
get_prng
(
size_t
idx
)
override
{
return
_prng
[
idx
];
}
private:
PseudorandomNumberGenerator
_prng
[
3
];
};
}
// namespace aby3
core/privc3/boolean_tensor_impl.h
浏览文件 @
d829309d
...
...
@@ -15,6 +15,7 @@
#pragma once
#include <algorithm>
#include "core/privc3/ot.h"
namespace
aby3
{
...
...
@@ -432,7 +433,7 @@ void BooleanTensor<T>::mul(const TensorAdapter<T>* rhs,
m
[
0
]
->
add
(
tmp
[
0
],
m
[
0
]);
m
[
1
]
->
add
(
tmp
[
0
],
m
[
1
]);
aby3_ctx
()
->
template
ot
(
idx0
,
idx1
,
idx2
,
null_arg
[
0
],
ObliviousTransfer
::
ot
(
idx0
,
idx1
,
idx2
,
null_arg
[
0
],
const_cast
<
const
aby3
::
TensorAdapter
<
T
>**>
(
m
),
tmp
,
null_arg
[
0
]);
...
...
@@ -445,7 +446,7 @@ void BooleanTensor<T>::mul(const TensorAdapter<T>* rhs,
// ret0 = s1
aby3_ctx
()
->
template
gen_zero_sharing_arithmetic
(
*
(
ret
-
>
mutable_share
(
0
)));
// ret1 = a * b + s0
aby3_ctx
()
->
template
ot
(
idx0
,
idx1
,
idx2
,
share
(
1
),
ObliviousTransfer
::
ot
(
idx0
,
idx1
,
idx2
,
share
(
1
),
const_cast
<
const
aby3
::
TensorAdapter
<
T
>**>
(
null_arg
),
tmp
,
ret
->
mutable_share
(
1
));
aby3_ctx
()
->
network
()
->
template
send
(
idx0
,
*
(
ret
-
>
share
(
0
)));
...
...
@@ -454,7 +455,7 @@ void BooleanTensor<T>::mul(const TensorAdapter<T>* rhs,
// ret0 = a * b + s0
aby3_ctx
()
->
template
gen_zero_sharing_arithmetic
(
*
(
ret
-
>
mutable_share
(
1
)));
// ret1 = s2
aby3_ctx
()
->
template
ot
(
idx0
,
idx1
,
idx2
,
share
(
0
),
ObliviousTransfer
::
ot
(
idx0
,
idx1
,
idx2
,
share
(
0
),
const_cast
<
const
aby3
::
TensorAdapter
<
T
>**>
(
null_arg
),
tmp
,
null_arg
[
0
]);
...
...
core/privc3/ot.h
0 → 100644
浏览文件 @
d829309d
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
namespace
aby3
{
class
ObliviousTransfer
{
public:
template
<
typename
T
,
template
<
typename
>
class
Tensor
>
static
inline
void
ot
(
size_t
sender
,
size_t
receiver
,
size_t
helper
,
const
Tensor
<
T
>*
choice
,
const
Tensor
<
T
>*
m
[
2
],
Tensor
<
T
>*
buffer
[
2
],
Tensor
<
T
>*
ret
)
{
// TODO: check tensor shape equals
auto
aby3_ctx
=
paddle
::
mpc
::
ContextHolder
::
mpc_ctx
();
const
size_t
numel
=
buffer
[
0
]
->
numel
();
if
(
aby3_ctx
->
party
()
==
sender
)
{
bool
common
=
helper
==
aby3_ctx
->
next_party
();
aby3_ctx
->
template
gen_random
(
*
buffer
[
0
],
common
);
aby3_ctx
->
template
gen_random
(
*
buffer
[
1
],
common
);
for
(
size_t
i
=
0
;
i
<
numel
;
++
i
)
{
buffer
[
0
]
->
data
()[
i
]
^=
m
[
0
]
->
data
()[
i
];
buffer
[
1
]
->
data
()[
i
]
^=
m
[
1
]
->
data
()[
i
];
}
aby3_ctx
->
network
()
->
template
send
(
receiver
,
*
buffer
[
0
]);
aby3_ctx
->
network
()
->
template
send
(
receiver
,
*
buffer
[
1
]);
}
else
if
(
aby3_ctx
->
party
()
==
helper
)
{
bool
common
=
sender
==
aby3_ctx
->
next_party
();
aby3_ctx
->
template
gen_random
(
*
buffer
[
0
],
common
);
aby3_ctx
->
template
gen_random
(
*
buffer
[
1
],
common
);
for
(
size_t
i
=
0
;
i
<
numel
;
++
i
)
{
buffer
[
0
]
->
data
()[
i
]
=
choice
->
data
()[
i
]
&
1
?
buffer
[
1
]
->
data
()[
i
]
:
buffer
[
0
]
->
data
()[
i
];
}
aby3_ctx
->
network
()
->
template
send
(
receiver
,
*
buffer
[
0
]);
}
else
if
(
aby3_ctx
->
party
()
==
receiver
)
{
aby3_ctx
->
network
()
->
template
recv
(
sender
,
*
buffer
[
0
]);
aby3_ctx
->
network
()
->
template
recv
(
sender
,
*
buffer
[
1
]);
aby3_ctx
->
network
()
->
template
recv
(
helper
,
*
ret
);
size_t
i
=
0
;
std
::
for_each
(
ret
->
data
(),
ret
->
data
()
+
numel
,
[
&
buffer
,
&
i
,
choice
,
ret
](
T
&
in
)
{
bool
c
=
choice
->
data
()[
i
]
&
1
;
in
^=
buffer
[
c
]
->
data
()[
i
];
++
i
;}
);
}
}
};
}
// namespace aby3
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录