Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ccfde2da
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ccfde2da
编写于
6月 29, 2022
作者:
Z
Zhen Wang
提交者:
GitHub
6月 29, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update the lock logic used in CinnCompiler::Compile. (#43876)
* Update the lock logic used in CinnCompiler::Compile.
上级
8bd69193
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
65 addition
and
70 deletion
+65
-70
paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
+63
-68
paddle/fluid/framework/paddle2cinn/cinn_compiler.h
paddle/fluid/framework/paddle2cinn/cinn_compiler.h
+2
-2
未找到文件。
paddle/fluid/framework/paddle2cinn/cinn_compiler.cc
浏览文件 @
ccfde2da
...
...
@@ -18,6 +18,7 @@
#include <iterator>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
...
...
@@ -43,7 +44,6 @@
#include "paddle/fluid/operators/cinn/cinn_launch_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/utils/rw_lock.h"
DECLARE_bool
(
enable_pe_launch_cinn
);
DECLARE_bool
(
enable_cinn_auto_tune
);
...
...
@@ -60,66 +60,61 @@ using inference::analysis::Dot;
using
ir
::
Graph
;
using
ir
::
Node
;
CinnCompiler
*
CinnCompiler
::
GetInstance
()
{
static
CinnCompiler
*
instance
=
new
CinnCompiler
();
CinnCompiler
*
CinnCompiler
::
GetInstance
()
{
static
CinnCompiler
*
instance
=
new
CinnCompiler
();
return
instance
;
}
const
CinnCompiledObject
&
CinnCompiler
::
Compile
(
const
Graph
&
graph
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
input_tensors
,
const
Target
&
target
,
void
*
stream
)
{
const
CinnCompiledObject
&
CinnCompiler
::
Compile
(
const
Graph
&
graph
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>
&
input_tensors
,
const
Target
&
target
,
void
*
stream
)
{
VLOG
(
4
)
<<
"-- The graph to be compiled is:
\n
"
<<
VizGraph
(
graph
);
CinnCacheKeyByAddress
cur_key_by_address
(
graph
,
input_tensors
,
target
.
arch_str
());
CinnCacheKeyByStructure
cur_key_by_struct
;
bool
exist
=
false
;
{
phi
::
AutoRDLock
r_guard
{
&
rwlock_
};
exist
=
cache_by_address_
.
count
(
cur_key_by_address
)
!=
0
;
// if cannot find graph by address, checkout whether the graph structure
// have been stored in cache.
if
(
!
exist
)
{
// generate the structure cache key
cur_key_by_struct
.
SetKey
(
graph
,
input_tensors
,
target
.
arch_str
());
// if the graph structure can be found, storing the graph address in
// cache for next query.
if
(
cache_by_struct_
.
count
(
cur_key_by_struct
)
!=
0
)
{
exist
=
true
;
if
(
!
cache_by_address_
.
count
(
cur_key_by_address
))
{
// generate the structure cache key
cur_key_by_struct
.
SetKey
(
graph
,
input_tensors
,
target
.
arch_str
());
if
(
!
cache_by_struct_
.
count
(
cur_key_by_struct
))
{
std
::
int64_t
compiled_num
=
real_compiled_num_
.
fetch_add
(
1
);
auto
compiled_res
=
CompileGraph
(
graph
,
input_tensors
,
target
,
compiled_num
,
stream
);
std
::
unique_lock
<
std
::
mutex
>
guard
(
lock_
);
// double check cache_by_struct_
if
(
!
cache_by_struct_
.
count
(
cur_key_by_struct
))
{
cache_by_struct_
[
cur_key_by_struct
]
=
compiled_num
;
index2cache_
.
emplace
(
compiled_num
,
std
::
move
(
compiled_res
));
}
// double check cache_by_address_
if
(
!
cache_by_address_
.
count
(
cur_key_by_address
))
{
cache_by_address_
[
cur_key_by_address
]
=
cache_by_struct_
.
at
(
cur_key_by_struct
);
}
}
else
{
std
::
unique_lock
<
std
::
mutex
>
guard
(
lock_
);
// double check cache_by_address_
if
(
!
cache_by_address_
.
count
(
cur_key_by_address
))
{
cache_by_address_
[
cur_key_by_address
]
=
cache_by_struct_
.
at
(
cur_key_by_struct
);
}
}
}
if
(
!
exist
)
{
std
::
int64_t
compiled_num
=
real_compiled_num_
.
fetch_add
(
1
);
auto
compiled_res
=
CompileGraph
(
graph
,
input_tensors
,
target
,
compiled_num
,
stream
);
phi
::
AutoWRLock
w_guard
{
&
rwlock_
};
if
(
!
cache_by_struct_
.
count
(
cur_key_by_struct
))
{
cache_by_address_
[
cur_key_by_address
]
=
compiled_num
;
cache_by_struct_
[
cur_key_by_struct
]
=
compiled_num
;
index2cache_
.
emplace
(
compiled_num
,
std
::
move
(
compiled_res
));
}
}
phi
::
AutoRDLock
guard
{
&
rwlock_
};
const
auto
&
cached_boj
=
*
index2cache_
[
cache_by_address_
[
cur_key_by_address
]];
return
cached_boj
;
return
*
index2cache_
.
at
(
cache_by_address_
.
at
(
cur_key_by_address
));
}
const
CinnCompiledObject
&
CinnCompiler
::
Compile
(
const
CinnCompiledObject
&
CinnCompiler
::
Compile
(
int64_t
compilation_key
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
input_tensors
,
const
Target
&
target
,
void
*
stream
)
{
const
auto
&
graph
=
FindGraph
(
compilation_key
);
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>
&
input_tensors
,
const
Target
&
target
,
void
*
stream
)
{
const
auto
&
graph
=
FindGraph
(
compilation_key
);
return
Compile
(
graph
,
input_tensors
,
target
,
stream
);
}
const
CinnCompiledObject
&
CinnCompiler
::
GetCompiledObject
(
const
CinnCompiledObject
&
CinnCompiler
::
GetCompiledObject
(
int64_t
cached_index
)
const
{
auto
res
=
index2cache_
.
find
(
cached_index
);
PADDLE_ENFORCE_NE
(
res
,
...
...
@@ -130,7 +125,7 @@ const CinnCompiledObject& CinnCompiler::GetCompiledObject(
}
int64_t
CinnCompiler
::
AddGraph
(
std
::
unique_ptr
<
Graph
>
graph
)
{
int64_t
graph_key
=
std
::
hash
<
Graph
*>
()((
&
(
*
graph
)));
int64_t
graph_key
=
std
::
hash
<
Graph
*>
()((
&
(
*
graph
)));
PADDLE_ENFORCE_EQ
(
graphs_
.
count
(
graph_key
),
0
,
...
...
@@ -143,7 +138,7 @@ int64_t CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
return
graph_key
;
}
const
Graph
&
CinnCompiler
::
FindGraph
(
int64_t
graph_key
)
const
{
const
Graph
&
CinnCompiler
::
FindGraph
(
int64_t
graph_key
)
const
{
auto
it
=
graphs_
.
find
(
graph_key
);
PADDLE_ENFORCE_NE
(
it
,
...
...
@@ -155,16 +150,16 @@ const Graph& CinnCompiler::FindGraph(int64_t graph_key) const {
}
std
::
string
CinnCompiler
::
VizGraph
(
int64_t
graph_key
)
const
{
const
Graph
&
graph
=
FindGraph
(
graph_key
);
const
Graph
&
graph
=
FindGraph
(
graph_key
);
return
VizGraph
(
graph
);
}
std
::
string
CinnCompiler
::
VizGraph
(
const
Graph
&
graph
)
const
{
std
::
string
CinnCompiler
::
VizGraph
(
const
Graph
&
graph
)
const
{
Dot
dot
;
std
::
unordered_map
<
const
Node
*
,
std
::
string
>
node2dot
;
std
::
unordered_map
<
const
Node
*
,
std
::
string
>
node2dot
;
int
id
=
0
;
// Create nodes
for
(
const
Node
*
n
:
graph
.
Nodes
())
{
for
(
const
Node
*
n
:
graph
.
Nodes
())
{
std
::
string
node_id
=
"Node"
+
std
::
to_string
(
id
++
);
if
(
n
->
IsOp
())
{
dot
.
AddNode
(
node_id
,
...
...
@@ -180,7 +175,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
auto
shape
=
n
->
Var
()
->
GetShape
();
std
::
vector
<
std
::
string
>
shape_str
(
shape
.
size
());
std
::
transform
(
shape
.
begin
(),
shape
.
end
(),
shape_str
.
begin
(),
[](
const
auto
&
val
)
{
shape
.
begin
(),
shape
.
end
(),
shape_str
.
begin
(),
[](
const
auto
&
val
)
{
return
std
::
to_string
(
val
);
});
label
+=
"
\n
"
+
string
::
join_strings
(
shape_str
,
','
);
...
...
@@ -198,10 +193,10 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
node2dot
[
n
]
=
node_id
;
}
// Create edges
for
(
const
Node
*
n
:
graph
.
Nodes
())
{
const
auto
&
src_id
=
node2dot
.
at
(
n
);
for
(
auto
*
out
:
n
->
outputs
)
{
const
auto
&
dest_id
=
node2dot
.
at
(
out
);
for
(
const
Node
*
n
:
graph
.
Nodes
())
{
const
auto
&
src_id
=
node2dot
.
at
(
n
);
for
(
auto
*
out
:
n
->
outputs
)
{
const
auto
&
dest_id
=
node2dot
.
at
(
out
);
dot
.
AddEdge
(
src_id
,
dest_id
,
{});
}
}
...
...
@@ -209,7 +204,7 @@ std::string CinnCompiler::VizGraph(const Graph& graph) const {
}
std
::
string
CinnCompiler
::
SerializeKey
(
int64_t
compilation_key
)
const
{
const
auto
&
graph
=
FindGraph
(
compilation_key
);
const
auto
&
graph
=
FindGraph
(
compilation_key
);
ProgramDesc
program
;
GraphToProgram
(
graph
,
&
program
);
...
...
@@ -220,7 +215,7 @@ std::string CinnCompiler::SerializeKey(int64_t compilation_key) const {
}
std
::
string
CinnCompiler
::
ReadableKey
(
int64_t
compilation_key
)
const
{
const
auto
&
graph
=
FindGraph
(
compilation_key
);
const
auto
&
graph
=
FindGraph
(
compilation_key
);
ProgramDesc
program
;
GraphToProgram
(
graph
,
&
program
);
...
...
@@ -230,7 +225,7 @@ std::string CinnCompiler::ReadableKey(int64_t compilation_key) const {
void
CinnCompiler
::
Clear
()
{
{
phi
::
AutoWRLock
guard
{
&
rwlock_
}
;
std
::
unique_lock
<
std
::
mutex
>
guard
(
lock_
)
;
graphs_
.
clear
();
cache_by_address_
.
clear
();
cache_by_struct_
.
clear
();
...
...
@@ -240,22 +235,22 @@ void CinnCompiler::Clear() {
}
void
CinnCompiler
::
CheckCompiledValid
(
const
ir
::
Graph
&
graph
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
input_tensors
,
const
CinnCompiledObject
&
compiled_obj
)
const
{
const
auto
&
input_var_names
=
graph
.
Get
<
std
::
vector
<
std
::
string
>>
(
kInputVars
);
const
auto
&
output_var_names
=
const
ir
::
Graph
&
graph
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>
&
input_tensors
,
const
CinnCompiledObject
&
compiled_obj
)
const
{
const
auto
&
input_var_names
=
graph
.
Get
<
std
::
vector
<
std
::
string
>>
(
kInputVars
);
const
auto
&
output_var_names
=
graph
.
Get
<
std
::
vector
<
std
::
string
>>
(
kOutputVars
);
auto
*
launch_context
=
compiled_obj
.
launch_context
.
get
();
auto
*
launch_context
=
compiled_obj
.
launch_context
.
get
();
// 1. check all of the output variables will be assigned by compiled program
for
(
auto
&&
var_name
:
output_var_names
)
{
for
(
auto
&&
var_name
:
output_var_names
)
{
PADDLE_ENFORCE_EQ
(
launch_context
->
IsVariableUsed
(
var_name
),
true
,
platform
::
errors
::
PreconditionNotMet
(
"Variable(%s) not applied in CINN"
,
var_name
));
}
// 2. check all of the used input variables were correctly deduced by CINN.
for
(
const
auto
&
var_name
:
input_var_names
)
{
for
(
const
auto
&
var_name
:
input_var_names
)
{
// some input variables were not used by CINN because they were eliminated
// by its optimized passes or some operators of it need less inputs
if
(
!
launch_context
->
IsVariableUsed
(
var_name
))
{
...
...
@@ -268,11 +263,11 @@ void CinnCompiler::CheckCompiledValid(
}
std
::
unique_ptr
<
CinnCompiledObject
>
CinnCompiler
::
CompileGraph
(
const
ir
::
Graph
&
graph
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
input_tensors
,
const
Target
&
target
,
const
ir
::
Graph
&
graph
,
const
std
::
map
<
std
::
string
,
const
LoDTensor
*>
&
input_tensors
,
const
Target
&
target
,
std
::
int64_t
compiled_num
,
void
*
stream
)
const
{
void
*
stream
)
const
{
CinnGraphSymbolization
symbol
{
compiled_num
,
graph
,
target
,
input_tensors
};
auto
frontend_program
=
symbol
();
auto
fetch_ids
=
symbol
.
GetFetchIds
();
...
...
paddle/fluid/framework/paddle2cinn/cinn_compiler.h
浏览文件 @
ccfde2da
...
...
@@ -18,6 +18,7 @@
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
...
...
@@ -26,7 +27,6 @@
#include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/phi/core/utils/rw_lock.h"
namespace
cinn
{
namespace
common
{
...
...
@@ -129,7 +129,7 @@ class CinnCompiler {
std
::
unordered_map
<
std
::
int64_t
,
std
::
unique_ptr
<
CinnCompiledObject
>>
index2cache_
;
std
::
atomic_int64_t
real_compiled_num_
{
0
};
mutable
phi
::
RWLock
rw
lock_
;
mutable
std
::
mutex
lock_
;
DISABLE_COPY_AND_ASSIGN
(
CinnCompiler
);
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录