BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit adcfc53b
Authored by baojun on Jul 30, 2019; committed by Tao Luo on Jul 31, 2019

upgrade ngraph version and simplify ngraph engine (#18853)

* upgrade ngraph to v0.24 test=develop
* simplify io test=develop

Parent: 2bb296df
Showing 3 changed files with 24 additions and 68 deletions (+24 −68)
cmake/external/ngraph.cmake                        +1  −1
paddle/fluid/operators/ngraph/ngraph_engine.cc     +23 −58
paddle/fluid/operators/ngraph/ngraph_engine.h      +0  −9
cmake/external/ngraph.cmake
@@ -37,7 +37,7 @@ INCLUDE(GNUInstallDirs)
 INCLUDE(ExternalProject)
 SET(NGRAPH_PROJECT       "extern_ngraph")
-SET(NGRAPH_GIT_TAG       "4ec94acc11084a5d53418f565529310fa584899a")
+SET(NGRAPH_GIT_TAG       "v0.24.0-rc.2")
 SET(NGRAPH_SOURCES_DIR   ${THIRD_PARTY_PATH}/ngraph)
 SET(NGRAPH_INSTALL_DIR   ${THIRD_PARTY_PATH}/install/ngraph)
 SET(NGRAPH_INC_DIR       ${NGRAPH_INSTALL_DIR}/include)
paddle/fluid/operators/ngraph/ngraph_engine.cc
@@ -92,13 +92,20 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
   std::vector<std::vector<int>> intervals;
   int size = ops->size();
-  int left = 0;
+  int left = 0, feed_idx = -1;
   while (left < size && ops->at(left)->Type() != framework::kFeedOpType &&
          ops->at(left)->Type() != "read" &&
          ops->at(left)->Type() != framework::kFetchOpType) {
     ++left;
   }
+  if (left < size) {
+    auto op_type = ops->at(left)->Type();
+    if (op_type == framework::kFeedOpType || op_type == "read") {
+      feed_idx = left;
+    }
+  }
   while (left < size && (ops->at(left)->Type() == framework::kFeedOpType ||
                          ops->at(left)->Type() == "read")) {
     for (auto& var_name_item : ops->at(left)->Outputs()) {
@@ -141,7 +148,9 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
         ++end;
       }
       std::vector<int> interval = {start, end};
-      intervals.emplace_back(interval);
+      if (feed_idx != -1 && start > feed_idx) {
+        intervals.emplace_back(interval);
+      }
     }
   }  // end while
   return intervals;
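Taken together, the two hunks above change how NgraphOpIntervals picks op ranges: the position of the leading feed/read block is remembered as feed_idx, and an interval is only kept when it starts after that block. The standalone sketch below mirrors that selection logic on plain op-type strings; OpIntervals and IsNgraphSupported are illustrative stand-ins under that assumption, not the engine's actual functions.

// Simplified, standalone sketch of the interval-selection idea (not Paddle code).
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for the real "is this op runnable by nGraph?" check.
static bool IsNgraphSupported(const std::string& type) {
  return type != "feed" && type != "fetch" && type != "read" &&
         type != "unsupported_op";
}

static std::vector<std::vector<int>> OpIntervals(
    const std::vector<std::string>& ops) {
  std::vector<std::vector<int>> intervals;
  int size = static_cast<int>(ops.size());
  int left = 0, feed_idx = -1;

  // Skip anything before the first feed/read/fetch boundary.
  while (left < size && ops[left] != "feed" && ops[left] != "read" &&
         ops[left] != "fetch") {
    ++left;
  }
  // Remember where the feed/read block starts, if there is one.
  if (left < size && (ops[left] == "feed" || ops[left] == "read")) {
    feed_idx = left;
  }
  // Move past the feed/read block itself.
  while (left < size && (ops[left] == "feed" || ops[left] == "read")) {
    ++left;
  }

  // Collect maximal runs of supported ops, keeping only runs after feed_idx.
  int start = left;
  while (start < size) {
    while (start < size && !IsNgraphSupported(ops[start])) ++start;
    int end = start;
    while (end < size && IsNgraphSupported(ops[end])) ++end;
    if (end > start && feed_idx != -1 && start > feed_idx) {
      intervals.push_back({start, end});
    }
    start = end;
  }
  return intervals;
}

int main() {
  for (const auto& iv : OpIntervals({"feed", "feed", "mul", "elementwise_add",
                                     "unsupported_op", "relu", "fetch"})) {
    std::cout << "[" << iv[0] << ", " << iv[1] << ")\n";  // [2, 4) and [5, 6)
  }
  return 0;
}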
@@ -252,7 +261,6 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
     NgraphEngine::p_bdesc = &block_desc;
   }
-  bool has_fetch = false, is_full = false;
   for (auto& var : p_bdesc->AllVars()) {
     if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
           var->GetType() == framework::proto::VarType::LOD_TENSOR ||
@@ -283,33 +291,12 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
   std::vector<paddle::framework::OpDesc*> ops_desc;
   for (auto op_desc : p_bdesc->AllOps()) {
     ops_desc.emplace_back(op_desc);
-    if (op_desc->Type() == framework::kFetchOpType) {
-      has_fetch = true;
-    }
   }
 
   for (auto op_desc : ops_desc) {
     if (op_desc->Type().find("_grad") != std::string::npos) {
       is_training = true;
       this->is_test_ = false;
       break;
     }
   }
 
-  if (interval[0] > 0 &&
-      ops_desc.at(interval[0] - 1)->Type() == framework::kFeedOpType &&
-      interval[1] < static_cast<int>(ops_desc.size()) &&
-      ops_desc.at(interval[1])->Type() == framework::kFetchOpType) {
-    is_full = true;
-  }
-
-  if (is_full) {
-    this->op_state_ = this->is_test_ ? OpState::FULL_TEST : OpState::FULL_TRAIN;
-  } else {
-    this->op_state_ =
-        this->is_test_ ? OpState::PARTIAL_TEST : OpState::PARTIAL_TRAIN;
-  }
-
   int idx = interval[0];
   while (idx < interval[1]) {
     this->fused_ops_.emplace_back(
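The lines removed in this hunk are the FULL/PARTIAL bookkeeping built on the now-deleted OpState enum; what the block still does is flip the engine out of test mode when any gradient op appears. Below is a minimal standalone sketch of that retained check, using plain op-type strings instead of framework::OpDesc (IsTestBlock is an illustrative name, not Paddle's API).

// Minimal sketch of the retained training-mode check (illustrative only).
#include <cassert>
#include <string>
#include <vector>

// A block is treated as "test" (inference) unless some op in it is a gradient
// op, which Paddle names with a "_grad" suffix (e.g. "mul_grad").
static bool IsTestBlock(const std::vector<std::string>& op_types) {
  for (const auto& type : op_types) {
    if (type.find("_grad") != std::string::npos) {
      return false;  // a backward op is present, so this is a training block
    }
  }
  return true;
}

int main() {
  assert(IsTestBlock({"feed", "mul", "softmax", "fetch"}));
  assert(!IsTestBlock({"feed", "mul", "mul_grad", "sgd", "fetch"}));
  return 0;
}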
@@ -327,10 +314,6 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
     ++idx;
   }
-  if (!has_fetch) {
-    op_state_ = OpState::UNKNOWN;
-  }
   if (var_in_.empty() && var_out_.empty()) {
     BuildNgIO(ops_desc, interval);
   }
@@ -380,37 +363,19 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
                      "op %s has more than 1 output - Not handling yet",
                      op->Type());
       for (auto& var_name : var_name_item.second) {
-        switch (this->op_state_) {
-          case OpState::PARTIAL_TEST:
-            if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
-                find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
-                    fetch_vars.end()) {
-              this->var_out_.emplace_back(var_name);
-            }
-            break;
-          case OpState::FULL_TEST:
-            if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
-                fetch_vars.end()) {
-              this->var_out_.emplace_back(var_name);
-            }
-            break;
-          case OpState::PARTIAL_TRAIN:
-            if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
-                    fetch_vars.end() ||
-                post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
-                persistables_.find(var_name) != persistables_.end()) {
-              this->var_out_.emplace_back(var_name);
-            }
-            break;
-          case OpState::FULL_TRAIN:
-            if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
-                    fetch_vars.end() ||
-                persistables_.find(var_name) != persistables_.end()) {
-              this->var_out_.emplace_back(var_name);
-            }
-            break;
-          default:
         if (this->is_test_) {
           if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
               find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
                   fetch_vars.end()) {
             this->var_out_.emplace_back(var_name);
           }
         } else {
           if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
                   fetch_vars.end() ||
               post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
               persistables_.find(var_name) != persistables_.end()) {
             this->var_out_.emplace_back(var_name);
           }
         }
       }
     }
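With the OpState switch removed, the surviving if/else chooses engine outputs directly: in test (inference) mode a variable becomes an output when a later, non-fused op reads it or when it is fetched; in training mode persistable variables are written back as well. The following standalone predicate is a sketch of that rule with simple container stand-ins for the engine's members, not the engine code itself.

// Standalone sketch of the output-variable rule in BuildNgIO (illustration).
#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct IoContext {
  bool is_test = true;                   // inference vs. training block
  std::set<std::string> post_op_inputs;  // vars read by ops after the block
  std::set<std::string> persistables;    // parameters / persistable vars
  std::vector<std::string> fetch_vars;   // vars requested by fetch ops
};

// Should `var` be exposed as an output of the fused nGraph block?
static bool IsEngineOutput(const IoContext& ctx, const std::string& var) {
  bool fetched = std::find(ctx.fetch_vars.begin(), ctx.fetch_vars.end(), var) !=
                 ctx.fetch_vars.end();
  bool used_later = ctx.post_op_inputs.count(var) > 0;
  if (ctx.is_test) {
    return used_later || fetched;
  }
  // Training: persistable variables (parameters) must also be written back.
  return fetched || used_later || ctx.persistables.count(var) > 0;
}

int main() {
  IoContext ctx;
  ctx.fetch_vars = {"loss"};
  ctx.post_op_inputs = {"fc_out"};
  ctx.persistables = {"fc_w", "fc_b"};

  std::cout << IsEngineOutput(ctx, "loss") << "\n";    // 1: fetched
  std::cout << IsEngineOutput(ctx, "fc_out") << "\n";  // 1: read by a later op
  std::cout << IsEngineOutput(ctx, "fc_w") << "\n";    // 0 in test mode
  ctx.is_test = false;
  std::cout << IsEngineOutput(ctx, "fc_w") << "\n";    // 1 in training mode
  return 0;
}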
paddle/fluid/operators/ngraph/ngraph_engine.h
@@ -30,14 +30,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-enum class OpState {    /* nGraph support state on ops          */
-  FULL_TRAIN,           /* Support full ops for train           */
-  PARTIAL_TRAIN,        /* Support partial ops for train        */
-  FULL_TEST,            /* Support full list of ops for test    */
-  PARTIAL_TEST,         /* Support partial list of ops for test */
-  UNKNOWN               /* Output all for debug purpose         */
-};
-
 // cache engine repetitives
 struct EngineCache {
   std::shared_ptr<ngraph::runtime::Executable> ngraph_handle;
@@ -78,7 +70,6 @@ class NgraphEngine {
   std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
   std::set<std::string> persistables_;
   std::unordered_set<std::string> post_op_inputs_;
-  OpState op_state_ = OpState::UNKNOWN;
   bool is_test_{true};
   std::string func_cache_key_;
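The header still carries the EngineCache struct (holding the compiled ngraph::runtime::Executable) and the per-engine func_cache_key_ string, so compiled functions continue to be reused by key rather than rebuilt on every run. The sketch below illustrates that compile-once, look-up-by-key pattern with a hypothetical Executable placeholder and cache class; it is not the engine's actual caching code.

// Generic compile-once / reuse-by-key pattern (hypothetical types throughout).
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Executable {  // placeholder for a compiled ngraph::runtime::Executable
  std::string description;
};

class EngineCache {
 public:
  // Return the cached executable for `key`, compiling it on the first request.
  std::shared_ptr<Executable> GetOrCompile(
      const std::string& key,
      const std::function<std::shared_ptr<Executable>()>& compile) {
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      return it->second;
    }
    auto handle = compile();  // expensive step, done once per key
    cache_[key] = handle;
    return handle;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<Executable>> cache_;
};

int main() {
  EngineCache cache;
  // A key like this might encode the fused interval and input shapes (assumed).
  const std::string key = "interval[2,5)/x:32x128/float32";
  auto compile = [] {
    std::cout << "compiling...\n";  // printed only once for this key
    return std::make_shared<Executable>(Executable{"compiled graph"});
  };
  auto a = cache.GetOrCompile(key, compile);
  auto b = cache.GetOrCompile(key, compile);  // served from the cache
  std::cout << (a == b) << "\n";              // 1
  return 0;
}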