Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9c81a9bb
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9c81a9bb
编写于
11月 03, 2021
作者:
Z
Zeng Jinle
提交者:
GitHub
11月 03, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix PTen thread safety error (#36960)
* fix pten thread safety error * improve coverage
上级
2479664a
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
52 addition
and
47 deletion
+52
-47
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+0
-6
paddle/fluid/framework/pten_utils.cc
paddle/fluid/framework/pten_utils.cc
+50
-16
paddle/fluid/framework/pten_utils.h
paddle/fluid/framework/pten_utils.h
+1
-24
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
9c81a9bb
...
...
@@ -399,7 +399,7 @@ cc_library(save_load_util SRCS save_load_util.cc DEPS tensor scope layer)
cc_test
(
save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer
)
cc_library
(
generator SRCS generator.cc DEPS enforce place
)
cc_library
(
pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows place pten var_type_traits pten_hapi_utils
)
cc_library
(
pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows place pten var_type_traits pten_hapi_utils
op_info
)
# Get the current working branch
execute_process
(
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
9c81a9bb
...
...
@@ -1762,12 +1762,6 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
KernelSignature
OperatorWithKernel
::
GetExpectedPtenKernelArgs
(
const
ExecutionContext
&
ctx
)
const
{
if
(
!
KernelSignatureMap
::
Instance
().
Has
(
Type
()))
{
// TODO(chenweihang): we can generate this map by proto info in compile time
KernelArgsNameMakerByOpProto
maker
(
Info
().
proto_
);
KernelSignatureMap
::
Instance
().
Emplace
(
Type
(),
std
::
move
(
maker
.
GetKernelSignature
()));
}
return
KernelSignatureMap
::
Instance
().
Get
(
Type
());
}
...
...
paddle/fluid/framework/pten_utils.cc
浏览文件 @
9c81a9bb
...
...
@@ -15,8 +15,10 @@ limitations under the License. */
#include <sstream>
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/pten/core/kernel_factory.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/string/string_helper.h"
...
...
@@ -24,6 +26,34 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
class
KernelArgsNameMakerByOpProto
:
public
KernelArgsNameMaker
{
public:
explicit
KernelArgsNameMakerByOpProto
(
const
framework
::
proto
::
OpProto
*
op_proto
)
:
op_proto_
(
op_proto
)
{
PADDLE_ENFORCE_NOT_NULL
(
op_proto_
,
platform
::
errors
::
InvalidArgument
(
"Op proto cannot be nullptr."
));
}
~
KernelArgsNameMakerByOpProto
()
{}
const
paddle
::
SmallVector
<
std
::
string
>&
GetInputArgsNames
()
override
;
const
paddle
::
SmallVector
<
std
::
string
>&
GetOutputArgsNames
()
override
;
const
paddle
::
SmallVector
<
std
::
string
>&
GetAttrsArgsNames
()
override
;
KernelSignature
GetKernelSignature
();
private:
DISABLE_COPY_AND_ASSIGN
(
KernelArgsNameMakerByOpProto
);
private:
const
framework
::
proto
::
OpProto
*
op_proto_
;
paddle
::
SmallVector
<
std
::
string
>
input_names_
;
paddle
::
SmallVector
<
std
::
string
>
output_names_
;
paddle
::
SmallVector
<
std
::
string
>
attr_names_
;
};
OpKernelType
TransPtenKernelKeyToOpKernelType
(
const
pten
::
KernelKey
&
kernel_key
)
{
proto
::
VarType
::
Type
data_type
=
...
...
@@ -60,15 +90,29 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey(
}
KernelSignatureMap
*
KernelSignatureMap
::
kernel_signature_map_
=
nullptr
;
std
::
mutex
KernelSignatureMap
::
mutex
_
;
std
::
once_flag
KernelSignatureMap
::
init_flag
_
;
KernelSignatureMap
&
KernelSignatureMap
::
Instance
()
{
if
(
kernel_signature_map_
==
nullptr
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
if
(
kernel_signature_map_
==
nullptr
)
{
kernel_signature_map_
=
new
KernelSignatureMap
;
std
::
call_once
(
init_flag_
,
[]
{
kernel_signature_map_
=
new
KernelSignatureMap
();
for
(
const
auto
&
pair
:
OpInfoMap
::
Instance
().
map
())
{
const
auto
&
op_type
=
pair
.
first
;
const
auto
*
op_proto
=
pair
.
second
.
proto_
;
if
(
pten
::
KernelFactory
::
Instance
().
HasCompatiblePtenKernel
(
op_type
))
{
KernelArgsNameMakerByOpProto
maker
(
op_proto
);
VLOG
(
10
)
<<
"Register kernel signature for "
<<
op_type
;
auto
success
=
kernel_signature_map_
->
map_
.
emplace
(
op_type
,
std
::
move
(
maker
.
GetKernelSignature
()))
.
second
;
PADDLE_ENFORCE_EQ
(
success
,
true
,
platform
::
errors
::
PermissionDenied
(
"Kernel signature of the operator %s has been registered."
,
op_type
));
}
}
}
}
);
return
*
kernel_signature_map_
;
}
...
...
@@ -76,16 +120,6 @@ bool KernelSignatureMap::Has(const std::string& op_type) const {
return
map_
.
find
(
op_type
)
!=
map_
.
end
();
}
void
KernelSignatureMap
::
Emplace
(
const
std
::
string
&
op_type
,
KernelSignature
&&
signature
)
{
if
(
!
Has
(
op_type
))
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
if
(
!
Has
(
op_type
))
{
map_
.
emplace
(
op_type
,
signature
);
}
}
}
const
KernelSignature
&
KernelSignatureMap
::
Get
(
const
std
::
string
&
op_type
)
const
{
auto
it
=
map_
.
find
(
op_type
);
...
...
paddle/fluid/framework/pten_utils.h
浏览文件 @
9c81a9bb
...
...
@@ -67,8 +67,6 @@ class KernelSignatureMap {
bool
Has
(
const
std
::
string
&
op_type
)
const
;
void
Emplace
(
const
std
::
string
&
op_type
,
KernelSignature
&&
signature
);
const
KernelSignature
&
Get
(
const
std
::
string
&
op_type
)
const
;
private:
...
...
@@ -77,7 +75,7 @@ class KernelSignatureMap {
private:
static
KernelSignatureMap
*
kernel_signature_map_
;
static
std
::
mutex
mutex
_
;
static
std
::
once_flag
init_flag
_
;
paddle
::
flat_hash_map
<
std
::
string
,
KernelSignature
>
map_
;
};
...
...
@@ -90,27 +88,6 @@ class KernelArgsNameMaker {
virtual
const
paddle
::
SmallVector
<
std
::
string
>&
GetAttrsArgsNames
()
=
0
;
};
class
KernelArgsNameMakerByOpProto
:
public
KernelArgsNameMaker
{
public:
explicit
KernelArgsNameMakerByOpProto
(
framework
::
proto
::
OpProto
*
op_proto
)
:
op_proto_
(
op_proto
)
{}
~
KernelArgsNameMakerByOpProto
()
{}
const
paddle
::
SmallVector
<
std
::
string
>&
GetInputArgsNames
()
override
;
const
paddle
::
SmallVector
<
std
::
string
>&
GetOutputArgsNames
()
override
;
const
paddle
::
SmallVector
<
std
::
string
>&
GetAttrsArgsNames
()
override
;
KernelSignature
GetKernelSignature
();
private:
framework
::
proto
::
OpProto
*
op_proto_
;
paddle
::
SmallVector
<
std
::
string
>
input_names_
;
paddle
::
SmallVector
<
std
::
string
>
output_names_
;
paddle
::
SmallVector
<
std
::
string
>
attr_names_
;
};
std
::
string
KernelSignatureToString
(
const
KernelSignature
&
signature
);
}
// namespace framework
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录