Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e2c1b7c3
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e2c1b7c3
编写于
6月 06, 2019
作者:
B
baojun
提交者:
tensor-tang
6月 06, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[NGraph] cache compiled function instead test=develop (#17845)
上级
d008260f
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
81 addition
and
75 deletion
+81
-75
paddle/fluid/operators/ngraph/ngraph_engine.cc
paddle/fluid/operators/ngraph/ngraph_engine.cc
+76
-71
paddle/fluid/operators/ngraph/ngraph_engine.h
paddle/fluid/operators/ngraph/ngraph_engine.h
+5
-4
未找到文件。
paddle/fluid/operators/ngraph/ngraph_engine.cc
浏览文件 @
e2c1b7c3
...
...
@@ -471,11 +471,11 @@ void NgraphEngine::BuildNgNodes() {
}
}
void
NgraphEngine
::
BuildNgFunction
(
const
framework
::
ExecutionContext
&
ctx
)
{
std
::
shared_ptr
<
ngraph
::
Function
>
NgraphEngine
::
BuildNgFunction
(
const
framework
::
ExecutionContext
&
ctx
)
{
Prepare
(
ctx
);
GetNgInputShape
();
BuildNgNodes
();
ngraph_function_
=
nullptr
;
ngraph
::
NodeVector
func_outputs
;
ngraph
::
ParameterVector
func_inputs
;
...
...
@@ -490,99 +490,105 @@ void NgraphEngine::BuildNgFunction(const framework::ExecutionContext& ctx) {
func_inputs
.
emplace_back
(
prm
);
}
ngraph_function_
=
std
::
make_shared
<
ngraph
::
Function
>
(
func_outputs
,
func_inputs
);
return
std
::
make_shared
<
ngraph
::
Function
>
(
func_outputs
,
func_inputs
);
}
void
NgraphEngine
::
ClearNgCache
()
{
auto
it
=
engine_cache
.
begin
();
while
(
it
!=
engine_cache
.
end
())
{
auto
ng_engine
=
it
->
second
;
backend_
->
remove_compiled_function
(
ng_engine
.
ngraph_handle
);
++
it
;
}
engine_cache
.
clear
();
auto
it_tensor
=
t_in_cache_
.
begin
();
while
(
it_tensor
!=
t_in_cache_
.
end
())
{
auto
t_vec
=
it_tensor
->
second
;
for
(
auto
t_in
:
t_vec
)
{
t_in
.
reset
();
}
++
it_tensor
;
}
t_in_cache_
.
clear
();
}
void
NgraphEngine
::
GetNgFunction
(
const
framework
::
ExecutionContext
&
ctx
)
{
auto
interval
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"interval"
);
std
::
string
engine_key
=
ctx
.
Attr
<
std
::
string
>
(
"engine_key"
);
// set to flase, to debug cache or recompile everytime.
bool
use_cache
=
true
;
if
(
use_cache
)
{
this
->
func_cache_key_
=
""
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
feed_vars
.
size
());
++
i
)
{
auto
*
var
=
scope_
.
FindVar
(
feed_vars
[
i
]);
if
(
var
&&
var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
*
tensor_pd
=
GetLoDTensorOrSelectedRowsValueFromVar
(
*
var
);
auto
dims
=
tensor_pd
->
dims
(
);
for
(
int
j
=
0
;
j
<
dims
.
size
();
++
j
)
{
func_cache_key_
+=
std
::
to_string
(
dims
[
j
]);
}
if
(
!
use_cache
)
ClearNgCache
();
this
->
func_cache_key_
=
""
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
feed_vars
.
size
());
++
i
)
{
auto
*
var
=
scope_
.
FindVar
(
feed_vars
[
i
]);
if
(
var
&&
var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
*
tensor_pd
=
GetLoDTensorOrSelectedRowsValueFromVar
(
*
var
);
auto
dims
=
tensor_pd
->
dims
();
for
(
int
j
=
0
;
j
<
dims
.
size
();
++
j
)
{
func_cache_key_
+=
std
::
to_string
(
dims
[
j
]);
}
}
func_cache_key_
+=
std
::
to_string
(
interval
[
0
])
+
"_"
+
std
::
to_string
(
interval
[
1
])
+
engine_key
;
func_cache_key_
=
std
::
to_string
(
std
::
hash
<
std
::
string
>
()(
func_cache_key_
));
if
(
engine_cache
.
find
(
func_cache_key_
)
!=
engine_cache
.
end
())
{
if
(
engine_cache
[
func_cache_key_
].
persistables
.
size
()
==
0
)
{
engine_cache
.
clear
();
t_in_cache_
.
clear
();
}
else
{
auto
var_name
=
engine_cache
[
func_cache_key_
].
persistables
.
begin
();
framework
::
Variable
*
var
=
scope_
.
FindVar
(
*
var_name
);
if
(
var
!=
pre_var_ptr
)
{
engine_cache
.
clear
();
t_in_cache_
.
clear
();
}
pre_var_ptr
=
var
;
}
func_cache_key_
+=
std
::
to_string
(
interval
[
0
])
+
"_"
+
std
::
to_string
(
interval
[
1
])
+
engine_key
;
func_cache_key_
=
std
::
to_string
(
std
::
hash
<
std
::
string
>
()(
func_cache_key_
));
if
(
engine_cache
.
find
(
func_cache_key_
)
!=
engine_cache
.
end
())
{
if
(
engine_cache
[
func_cache_key_
].
persistables
.
size
()
==
0
)
{
ClearNgCache
();
}
else
{
auto
var_name
=
engine_cache
[
func_cache_key_
].
persistables
.
begin
();
framework
::
Variable
*
var
=
scope_
.
FindVar
(
*
var_name
);
if
(
var
!=
pre_var_ptr
)
{
ClearNgCache
();
}
pre_var_ptr
=
var
;
}
}
if
(
engine_cache
.
find
(
func_cache_key_
)
==
engine_cache
.
end
())
{
BuildNgFunction
(
ctx
);
engine_cache
[
func_cache_key_
].
ngraph_function
=
this
->
ngraph_function_
;
engine_cache
[
func_cache_key_
].
persistables
=
this
->
persistables_
;
engine_cache
[
func_cache_key_
].
var_in_updates
=
this
->
var_in_updates_
;
engine_cache
[
func_cache_key_
].
var_in
=
this
->
var_in_
;
engine_cache
[
func_cache_key_
].
var_out
=
this
->
var_out_
;
engine_cache
[
func_cache_key_
].
is_test
=
this
->
is_test_
;
if
(
engine_cache
.
find
(
func_cache_key_
)
==
engine_cache
.
end
())
{
if
(
engine_cache
.
size
()
>
5
)
ClearNgCache
();
auto
func
=
BuildNgFunction
(
ctx
);
// Due to optimization backend may produce results in other layouts,
// make sure we get default layout for results.
for
(
auto
&
r
:
func
->
get_results
())
{
r
->
set_needs_default_layout
(
true
);
}
}
else
{
BuildNgFunction
(
ctx
);
engine_cache
[
func_cache_key_
].
ngraph_handle
=
backend_
->
compile
(
func
);
engine_cache
[
func_cache_key_
].
persistables
=
this
->
persistables_
;
engine_cache
[
func_cache_key_
].
var_in_updates
=
this
->
var_in_updates_
;
engine_cache
[
func_cache_key_
].
var_in
=
this
->
var_in_
;
engine_cache
[
func_cache_key_
].
var_out
=
this
->
var_out_
;
engine_cache
[
func_cache_key_
].
is_test
=
this
->
is_test_
;
}
}
void
NgraphEngine
::
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
std
::
shared_ptr
<
ngraph
::
Function
>
ng_func
;
std
::
shared_ptr
<
ngraph
::
runtime
::
Executable
>
ng_handle
;
const
std
::
set
<
std
::
string
>*
p_persistables
;
const
std
::
vector
<
size_t
>*
p_var_in_updates
;
const
std
::
vector
<
std
::
string
>*
p_var_in
;
const
std
::
vector
<
std
::
string
>*
p_var_out
;
bool
is_test
;
bool
use_cache
=
true
;
if
(
use_cache
)
{
PADDLE_ENFORCE
(
engine_cache
.
find
(
func_cache_key_
)
!=
engine_cache
.
end
(),
"Cannot find cached data to run ngraph function"
);
ng_func
=
engine_cache
[
func_cache_key_
].
ngraph_function
;
p_persistables
=
&
(
engine_cache
[
func_cache_key_
].
persistables
);
p_var_in_updates
=
&
(
engine_cache
[
func_cache_key_
].
var_in_updates
);
p_var_in
=
&
(
engine_cache
[
func_cache_key_
].
var_in
);
p_var_out
=
&
(
engine_cache
[
func_cache_key_
].
var_out
);
is_test
=
engine_cache
[
func_cache_key_
].
is_test
;
}
else
{
ng_func
=
ngraph_function_
;
p_persistables
=
&
this
->
persistables_
;
p_var_in_updates
=
&
this
->
var_in_updates_
;
p_var_in
=
&
this
->
var_in_
;
p_var_out
=
&
this
->
var_out_
;
is_test
=
this
->
is_test_
;
}
PADDLE_ENFORCE
(
engine_cache
.
find
(
func_cache_key_
)
!=
engine_cache
.
end
(),
"Cannot find cached data to run ngraph function"
);
ng_handle
=
engine_cache
[
func_cache_key_
].
ngraph_handle
;
p_persistables
=
&
(
engine_cache
[
func_cache_key_
].
persistables
);
p_var_in_updates
=
&
(
engine_cache
[
func_cache_key_
].
var_in_updates
);
p_var_in
=
&
(
engine_cache
[
func_cache_key_
].
var_in
);
p_var_out
=
&
(
engine_cache
[
func_cache_key_
].
var_out
);
is_test
=
engine_cache
[
func_cache_key_
].
is_test
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Tensor
>>*
p_t_in
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Tensor
>>
t_in
=
{};
auto
m_parameters
=
ng_func
->
get_parameters
();
auto
m_results
=
ng_func
->
get_results
();
// Due to optimization backend may produce results in other layouts,
// make sure we get default layout for results.
for
(
auto
&
r
:
m_results
)
{
r
->
set_needs_default_layout
(
true
);
}
if
(
is_test
&&
use_cache
&&
t_in_cache_
.
find
(
func_cache_key_
)
!=
t_in_cache_
.
end
())
{
auto
m_parameters
=
ng_handle
->
get_parameters
();
auto
m_results
=
ng_handle
->
get_results
();
if
(
is_test
&&
t_in_cache_
.
find
(
func_cache_key_
)
!=
t_in_cache_
.
end
())
{
p_t_in
=
&
(
t_in_cache_
[
func_cache_key_
]);
for
(
size_t
i
=
0
;
i
<
p_var_in_updates
->
size
();
++
i
)
{
int
index
=
p_var_in_updates
->
at
(
i
);
...
...
@@ -601,7 +607,7 @@ void NgraphEngine::Run(const framework::Scope& scope,
}
}
}
else
{
if
(
is_test
&&
use_cache
)
{
if
(
is_test
)
{
p_t_in
=
&
(
t_in_cache_
[
func_cache_key_
]);
}
else
{
p_t_in
=
&
t_in
;
...
...
@@ -664,8 +670,7 @@ void NgraphEngine::Run(const framework::Scope& scope,
}
}
auto
handle
=
backend_
->
compile
(
ng_func
);
handle
->
call_with_validate
(
t_out
,
*
p_t_in
);
ng_handle
->
call
(
t_out
,
*
p_t_in
);
}
// NgraphEngine::Run
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/ngraph/ngraph_engine.h
浏览文件 @
e2c1b7c3
...
...
@@ -40,7 +40,7 @@ enum class OpState { /* nGraph support state on ops */
// cache engine repetitives
struct
EngineCache
{
std
::
shared_ptr
<
ngraph
::
Function
>
ngraph_function
;
std
::
shared_ptr
<
ngraph
::
runtime
::
Executable
>
ngraph_handle
;
std
::
set
<
std
::
string
>
persistables
;
std
::
vector
<
std
::
string
>
var_in
;
std
::
vector
<
std
::
string
>
var_out
;
...
...
@@ -84,8 +84,6 @@ class NgraphEngine {
// ngraph backend eg. CPU
static
std
::
shared_ptr
<
ngraph
::
runtime
::
Backend
>
backend_
;
// ngraph function to call and execute
std
::
shared_ptr
<
ngraph
::
Function
>
ngraph_function_
;
// var_name of inputs
std
::
vector
<
std
::
string
>
var_in_
;
// var_name of outputs from fetch in order
...
...
@@ -110,7 +108,10 @@ class NgraphEngine {
// Call ngraph bridge to map ops
void
BuildNgNodes
();
// build ngraph function call
void
BuildNgFunction
(
const
framework
::
ExecutionContext
&
ctx
);
std
::
shared_ptr
<
ngraph
::
Function
>
BuildNgFunction
(
const
framework
::
ExecutionContext
&
ctx
);
// clear ngraph engine cache and t_in cache
void
ClearNgCache
();
// Check cache for ngraph function or otherwise build the function
void
GetNgFunction
(
const
framework
::
ExecutionContext
&
ctx
);
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录