Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ccfde2da
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ccfde2da
编写于
6月 29, 2022
作者:
Z
Zhen Wang
提交者:
GitHub
6月 29, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update the lock logic used in CinnCompiler::Compile. (#43876)
* Update the lock logic used in CinnCompiler::Compile.
上级
8bd69193
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
65 addition
and
70 deletion
+65
-70
paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
+63
-68
paddle/fluid/framework/paddle2cinn/cinn_compiler.h
paddle/fluid/framework/paddle2cinn/cinn_compiler.h
+2
-2
未找到文件。
paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
浏览文件 @
ccfde2da
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include <iterator>
#include <iterator>
#include <map>
#include <map>
#include <memory>
#include <memory>
#include <mutex>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
...
@@ -43,7 +44,6 @@
...
@@ -43,7 +44,6 @@
#include "paddle/fluid/operators/cinn/cinn_launch_context.h"
#include "paddle/fluid/operators/cinn/cinn_launch_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/utils/rw_lock.h"
DECLARE_bool
(
enable_pe_launch_cinn
);
DECLARE_bool
(
enable_pe_launch_cinn
);
DECLARE_bool
(
enable_cinn_auto_tune
);
DECLARE_bool
(
enable_cinn_auto_tune
);
...
@@ -60,66 +60,61 @@ using inference::analysis::Dot;
...
@@ -60,66 +60,61 @@ using inference::analysis::Dot;
using
ir
::
Graph
;
using
ir
::
Graph
;
using
ir
::
Node
;
using
ir
::
Node
;
CinnCompiler
*
CinnCompiler
::
GetInstance
()
{
CinnCompiler
*
CinnCompiler
::
GetInstance
()
{
static
CinnCompiler
*
instance
=
new
CinnCompiler
();
static
CinnCompiler
*
instance
=
new
CinnCompiler
();
return
instance
;
return
instance
;
}
}
const
CinnCompiledObject
&
CinnCompiler
::
Compile
(
const
CinnCompiledObject
&
CinnCompiler
::
Compile
(
const
Graph
&
graph
,
const
Graph
&
graph
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
input_tensors
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>
&
input_tensors
,
const
Target
&
target
,
const
Target
&
target
,
void
*
stream
)
{
void
*
stream
)
{
VLOG
(
4
)
<<
"-- The graph to be compiled is:
\n
"
<<
VizGraph
(
graph
);
VLOG
(
4
)
<<
"-- The graph to be compiled is:
\n
"
<<
VizGraph
(
graph
);
CinnCacheKeyByAddress
cur_key_by_address
(
CinnCacheKeyByAddress
cur_key_by_address
(
graph
,
input_tensors
,
target
.
arch_str
());
graph
,
input_tensors
,
target
.
arch_str
());
CinnCacheKeyByStructure
cur_key_by_struct
;
CinnCacheKeyByStructure
cur_key_by_struct
;
bool
exist
=
false
;
if
(
!
cache_by_address_
.
count
(
cur_key_by_address
))
{
{
// generate the structure cache key
phi
::
AutoRDLock
r_guard
{
&
rwlock_
};
cur_key_by_struct
.
SetKey
(
graph
,
input_tensors
,
target
.
arch_str
());
exist
=
cache_by_address_
.
count
(
cur_key_by_address
)
!=
0
;
if
(
!
cache_by_struct_
.
count
(
cur_key_by_struct
))
{
// if cannot find graph by address, checkout whether the graph structure
std
::
int64_t
compiled_num
=
real_compiled_num_
.
fetch_add
(
1
);
// have been stored in cache.
auto
compiled_res
=
if
(
!
exist
)
{
CompileGraph
(
graph
,
input_tensors
,
target
,
compiled_num
,
stream
);
// generate the structure cache key
std
::
unique_lock
<
std
::
mutex
>
guard
(
lock_
);
cur_key_by_struct
.
SetKey
(
graph
,
input_tensors
,
target
.
arch_str
());
// double check cache_by_struct_
if
(
!
cache_by_struct_
.
count
(
cur_key_by_struct
))
{
// if the graph structure can be found, storing the graph address in
cache_by_struct_
[
cur_key_by_struct
]
=
compiled_num
;
// cache for next query.
index2cache_
.
emplace
(
compiled_num
,
std
::
move
(
compiled_res
));
if
(
cache_by_struct_
.
count
(
cur_key_by_struct
)
!=
0
)
{
}
exist
=
true
;
// double check cache_by_address_
if
(
!
cache_by_address_
.
count
(
cur_key_by_address
))
{
cache_by_address_
[
cur_key_by_address
]
=
cache_by_struct_
.
at
(
cur_key_by_struct
);
}
}
else
{
std
::
unique_lock
<
std
::
mutex
>
guard
(
lock_
);
// double check cache_by_address_
if
(
!
cache_by_address_
.
count
(
cur_key_by_address
))
{
cache_by_address_
[
cur_key_by_address
]
=
cache_by_address_
[
cur_key_by_address
]
=
cache_by_struct_
.
at
(
cur_key_by_struct
);
cache_by_struct_
.
at
(
cur_key_by_struct
);
}
}
}
}
}
}
if
(
!
exist
)
{
return
*
index2cache_
.
at
(
cache_by_address_
.
at
(
cur_key_by_address
));
std
::
int64_t
compiled_num
=
real_compiled_num_
.
fetch_add
(
1
);
auto
compiled_res
=
CompileGraph
(
graph
,
input_tensors
,
target
,
compiled_num
,
stream
);
phi
::
AutoWRLock
w_guard
{
&
rwlock_
};
if
(
!
cache_by_struct_
.
count
(
cur_key_by_struct
))
{
cache_by_address_
[
cur_key_by_address
]
=
compiled_num
;
cache_by_struct_
[
cur_key_by_struct
]
=
compiled_num
;
index2cache_
.
emplace
(
compiled_num
,
std
::
move
(
compiled_res
));
}
}
phi
::
AutoRDLock
guard
{
&
rwlock_
};
const
auto
&
cached_boj
=
*
index2cache_
[
cache_by_address_
[
cur_key_by_address
]];
return
cached_boj
;
}
}
const
CinnCompiledObject
&
CinnCompiler
::
Compile
(
const
CinnCompiledObject
&
CinnCompiler
::
Compile
(
int64_t
compilation_key
,
int64_t
compilation_key
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
input_tensors
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>
&
input_tensors
,
const
Target
&
target
,
const
Target
&
target
,
void
*
stream
)
{
void
*
stream
)
{
const
auto
&
graph
=
FindGraph
(
compilation_key
);
const
auto
&
graph
=
FindGraph
(
compilation_key
);
return
Compile
(
graph
,
input_tensors
,
target
,
stream
);
return
Compile
(
graph
,
input_tensors
,
target
,
stream
);
}
}
const
CinnCompiledObject
&
CinnCompiler
::
GetCompiledObject
(
const
CinnCompiledObject
&
CinnCompiler
::
GetCompiledObject
(
int64_t
cached_index
)
const
{
int64_t
cached_index
)
const
{
auto
res
=
index2cache_
.
find
(
cached_index
);
auto
res
=
index2cache_
.
find
(
cached_index
);
PADDLE_ENFORCE_NE
(
res
,
PADDLE_ENFORCE_NE
(
res
,
...
@@ -130,7 +125,7 @@ const CinnCompiledObject& CinnCompiler::GetCompiledObject(
...
@@ -130,7 +125,7 @@ const CinnCompiledObject& CinnCompiler::GetCompiledObject(
}
}
int64_t
CinnCompiler
::
AddGraph
(
std
::
unique_ptr
<
Graph
>
graph
)
{
int64_t
CinnCompiler
::
AddGraph
(
std
::
unique_ptr
<
Graph
>
graph
)
{
int64_t
graph_key
=
std
::
hash
<
Graph
*>
()((
&
(
*
graph
)));
int64_t
graph_key
=
std
::
hash
<
Graph
*>
()((
&
(
*
graph
)));
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
graphs_
.
count
(
graph_key
),
graphs_
.
count
(
graph_key
),
0
,
0
,
...
@@ -143,7 +138,7 @@ int64_t CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
...
@@ -143,7 +138,7 @@ int64_t CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
return
graph_key
;
return
graph_key
;
}
}
const
Graph
&
CinnCompiler
::
FindGraph
(
int64_t
graph_key
)
const
{
const
Graph
&
CinnCompiler
::
FindGraph
(
int64_t
graph_key
)
const
{
auto
it
=
graphs_
.
find
(
graph_key
);
auto
it
=
graphs_
.
find
(
graph_key
);
PADDLE_ENFORCE_NE
(
PADDLE_ENFORCE_NE
(
it
,
it
,
...
@@ -155,16 +150,16 @@ const Graph& CinnCompiler::FindGraph(int64_t graph_key) const {
...
@@ -155,16 +150,16 @@ const Graph& CinnCompiler::FindGraph(int64_t graph_key) const {
}
}
std
::
string
CinnCompiler
::
VizGraph
(
int64_t
graph_key
)
const
{
std
::
string
CinnCompiler
::
VizGraph
(
int64_t
graph_key
)
const
{
const
Graph
&
graph
=
FindGraph
(
graph_key
);
const
Graph
&
graph
=
FindGraph
(
graph_key
);
return
VizGraph
(
graph
);
return
VizGraph
(
graph
);
}
}
std
::
string
CinnCompiler
::
VizGraph
(
const
Graph
&
graph
)
const
{
std
::
string
CinnCompiler
::
VizGraph
(
const
Graph
&
graph
)
const
{
Dot
dot
;
Dot
dot
;
std
::
unordered_map
<
const
Node
*
,
std
::
string
>
node2dot
;
std
::
unordered_map
<
const
Node
*
,
std
::
string
>
node2dot
;
int
id
=
0
;
int
id
=
0
;
// Create nodes
// Create nodes
for
(
const
Node
*
n
:
graph
.
Nodes
())
{
for
(
const
Node
*
n
:
graph
.
Nodes
())
{
std
::
string
node_id
=
"Node"
+
std
::
to_string
(
id
++
);
std
::
string
node_id
=
"Node"
+
std
::
to_string
(
id
++
);
if
(
n
->
IsOp
())
{
if
(
n
->
IsOp
())
{
dot
.
AddNode
(
node_id
,
dot
.
AddNode
(
node_id
,
...
@@ -180,7 +175,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
...
@@ -180,7 +175,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
auto
shape
=
n
->
Var
()
->
GetShape
();
auto
shape
=
n
->
Var
()
->
GetShape
();
std
::
vector
<
std
::
string
>
shape_str
(
shape
.
size
());
std
::
vector
<
std
::
string
>
shape_str
(
shape
.
size
());
std
::
transform
(
std
::
transform
(
shape
.
begin
(),
shape
.
end
(),
shape_str
.
begin
(),
[](
const
auto
&
val
)
{
shape
.
begin
(),
shape
.
end
(),
shape_str
.
begin
(),
[](
const
auto
&
val
)
{
return
std
::
to_string
(
val
);
return
std
::
to_string
(
val
);
});
});
label
+=
"
\n
"
+
string
::
join_strings
(
shape_str
,
','
);
label
+=
"
\n
"
+
string
::
join_strings
(
shape_str
,
','
);
...
@@ -198,10 +193,10 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
...
@@ -198,10 +193,10 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
node2dot
[
n
]
=
node_id
;
node2dot
[
n
]
=
node_id
;
}
}
// Create edges
// Create edges
for
(
const
Node
*
n
:
graph
.
Nodes
())
{
for
(
const
Node
*
n
:
graph
.
Nodes
())
{
const
auto
&
src_id
=
node2dot
.
at
(
n
);
const
auto
&
src_id
=
node2dot
.
at
(
n
);
for
(
auto
*
out
:
n
->
outputs
)
{
for
(
auto
*
out
:
n
->
outputs
)
{
const
auto
&
dest_id
=
node2dot
.
at
(
out
);
const
auto
&
dest_id
=
node2dot
.
at
(
out
);
dot
.
AddEdge
(
src_id
,
dest_id
,
{});
dot
.
AddEdge
(
src_id
,
dest_id
,
{});
}
}
}
}
...
@@ -209,7 +204,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
...
@@ -209,7 +204,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
}
}
std
::
string
CinnCompiler
::
SerializeKey
(
int64_t
compilation_key
)
const
{
std
::
string
CinnCompiler
::
SerializeKey
(
int64_t
compilation_key
)
const
{
const
auto
&
graph
=
FindGraph
(
compilation_key
);
const
auto
&
graph
=
FindGraph
(
compilation_key
);
ProgramDesc
program
;
ProgramDesc
program
;
GraphToProgram
(
graph
,
&
program
);
GraphToProgram
(
graph
,
&
program
);
...
@@ -220,7 +215,7 @@ std::string CinnCompiler::SerializeKey(int64_t compilation_key) const {
...
@@ -220,7 +215,7 @@ std::string CinnCompiler::SerializeKey(int64_t compilation_key) const {
}
}
std
::
string
CinnCompiler
::
ReadableKey
(
int64_t
compilation_key
)
const
{
std
::
string
CinnCompiler
::
ReadableKey
(
int64_t
compilation_key
)
const
{
const
auto
&
graph
=
FindGraph
(
compilation_key
);
const
auto
&
graph
=
FindGraph
(
compilation_key
);
ProgramDesc
program
;
ProgramDesc
program
;
GraphToProgram
(
graph
,
&
program
);
GraphToProgram
(
graph
,
&
program
);
...
@@ -230,7 +225,7 @@ std::string CinnCompiler::ReadableKey(int64_t compilation_key) const {
...
@@ -230,7 +225,7 @@ std::string CinnCompiler::ReadableKey(int64_t compilation_key) const {
void
CinnCompiler
::
Clear
()
{
void
CinnCompiler
::
Clear
()
{
{
{
phi
::
AutoWRLock
guard
{
&
rwlock_
}
;
std
::
unique_lock
<
std
::
mutex
>
guard
(
lock_
)
;
graphs_
.
clear
();
graphs_
.
clear
();
cache_by_address_
.
clear
();
cache_by_address_
.
clear
();
cache_by_struct_
.
clear
();
cache_by_struct_
.
clear
();
...
@@ -240,22 +235,22 @@ void CinnCompiler::Clear() {
...
@@ -240,22 +235,22 @@ void CinnCompiler::Clear() {
}
}
void
CinnCompiler
::
CheckCompiledValid
(
void
CinnCompiler
::
CheckCompiledValid
(
const
ir
::
Graph
&
graph
,
const
ir
::
Graph
&
graph
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
input_tensors
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>
&
input_tensors
,
const
CinnCompiledObject
&
compiled_obj
)
const
{
const
CinnCompiledObject
&
compiled_obj
)
const
{
const
auto
&
input_var_names
=
graph
.
Get
<
std
::
vector
<
std
::
string
>>
(
kInputVars
);
const
auto
&
input_var_names
=
graph
.
Get
<
std
::
vector
<
std
::
string
>>
(
kInputVars
);
const
auto
&
output_var_names
=
const
auto
&
output_var_names
=
graph
.
Get
<
std
::
vector
<
std
::
string
>>
(
kOutputVars
);
graph
.
Get
<
std
::
vector
<
std
::
string
>>
(
kOutputVars
);
auto
*
launch_context
=
compiled_obj
.
launch_context
.
get
();
auto
*
launch_context
=
compiled_obj
.
launch_context
.
get
();
// 1. check all of the output variables will be assigned by compiled program
// 1. check all of the output variables will be assigned by compiled program
for
(
auto
&&
var_name
:
output_var_names
)
{
for
(
auto
&&
var_name
:
output_var_names
)
{
PADDLE_ENFORCE_EQ
(
launch_context
->
IsVariableUsed
(
var_name
),
PADDLE_ENFORCE_EQ
(
launch_context
->
IsVariableUsed
(
var_name
),
true
,
true
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"Variable(%s) not applied in CINN"
,
var_name
));
"Variable(%s) not applied in CINN"
,
var_name
));
}
}
// 2. check all of the used input variables were correctly deduced by CINN.
// 2. check all of the used input variables were correctly deduced by CINN.
for
(
const
auto
&
var_name
:
input_var_names
)
{
for
(
const
auto
&
var_name
:
input_var_names
)
{
// some input variables were not used by CINN because they were eliminated
// some input variables were not used by CINN because they were eliminated
// by its optimized passes or some operators of it need less inputs
// by its optimized passes or some operators of it need less inputs
if
(
!
launch_context
->
IsVariableUsed
(
var_name
))
{
if
(
!
launch_context
->
IsVariableUsed
(
var_name
))
{
...
@@ -268,11 +263,11 @@ void CinnCompiler::CheckCompiledValid(
...
@@ -268,11 +263,11 @@ void CinnCompiler::CheckCompiledValid(
}
}
std
::
unique_ptr
<
CinnCompiledObject
>
CinnCompiler
::
CompileGraph
(
std
::
unique_ptr
<
CinnCompiledObject
>
CinnCompiler
::
CompileGraph
(
const
ir
::
Graph
&
graph
,
const
ir
::
Graph
&
graph
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
input_tensors
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>
&
input_tensors
,
const
Target
&
target
,
const
Target
&
target
,
std
::
int64_t
compiled_num
,
std
::
int64_t
compiled_num
,
void
*
stream
)
const
{
void
*
stream
)
const
{
CinnGraphSymbolization
symbol
{
compiled_num
,
graph
,
target
,
input_tensors
};
CinnGraphSymbolization
symbol
{
compiled_num
,
graph
,
target
,
input_tensors
};
auto
frontend_program
=
symbol
();
auto
frontend_program
=
symbol
();
auto
fetch_ids
=
symbol
.
GetFetchIds
();
auto
fetch_ids
=
symbol
.
GetFetchIds
();
...
...
paddle/fluid/framework/paddle2cinn/cinn_compiler.h
浏览文件 @
ccfde2da
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include <cstdint>
#include <cstdint>
#include <map>
#include <map>
#include <memory>
#include <memory>
#include <mutex>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
...
@@ -26,7 +27,6 @@
...
@@ -26,7 +27,6 @@
#include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/phi/core/utils/rw_lock.h"
namespace
cinn
{
namespace
cinn
{
namespace
common
{
namespace
common
{
...
@@ -129,7 +129,7 @@ class CinnCompiler {
...
@@ -129,7 +129,7 @@ class CinnCompiler {
std
::
unordered_map
<
std
::
int64_t
,
std
::
unique_ptr
<
CinnCompiledObject
>>
std
::
unordered_map
<
std
::
int64_t
,
std
::
unique_ptr
<
CinnCompiledObject
>>
index2cache_
;
index2cache_
;
std
::
atomic_int64_t
real_compiled_num_
{
0
};
std
::
atomic_int64_t
real_compiled_num_
{
0
};
mutable
phi
::
RWLock
rw
lock_
;
mutable
std
::
mutex
lock_
;
DISABLE_COPY_AND_ASSIGN
(
CinnCompiler
);
DISABLE_COPY_AND_ASSIGN
(
CinnCompiler
);
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录