Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a0af374f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a0af374f
编写于
4月 08, 2019
作者:
B
baojun
提交者:
tensor-tang
4月 09, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix training validation test=release/1.4 (#16716)
上级
266cdf7d
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
7 addition
and
2 deletion
+7
-2
paddle/fluid/operators/ngraph/ngraph_engine.cc
paddle/fluid/operators/ngraph/ngraph_engine.cc
+6
-2
paddle/fluid/operators/ngraph/ngraph_engine.h
paddle/fluid/operators/ngraph/ngraph_engine.h
+1
-0
未找到文件。
paddle/fluid/operators/ngraph/ngraph_engine.cc
浏览文件 @
a0af374f
...
@@ -75,6 +75,7 @@ std::vector<std::string> NgraphEngine::feed_vars = {};
...
@@ -75,6 +75,7 @@ std::vector<std::string> NgraphEngine::feed_vars = {};
std
::
vector
<
std
::
string
>
NgraphEngine
::
fetch_vars
=
{};
std
::
vector
<
std
::
string
>
NgraphEngine
::
fetch_vars
=
{};
framework
::
Variable
*
NgraphEngine
::
pre_var_ptr
=
nullptr
;
framework
::
Variable
*
NgraphEngine
::
pre_var_ptr
=
nullptr
;
const
framework
::
BlockDesc
*
NgraphEngine
::
p_bdesc
=
nullptr
;
const
framework
::
BlockDesc
*
NgraphEngine
::
p_bdesc
=
nullptr
;
bool
NgraphEngine
::
is_training
=
false
;
std
::
unordered_map
<
std
::
string
,
EngineCache
>
NgraphEngine
::
engine_cache
=
{};
std
::
unordered_map
<
std
::
string
,
EngineCache
>
NgraphEngine
::
engine_cache
=
{};
std
::
unordered_map
<
std
::
string
,
std
::
unordered_map
<
std
::
string
,
...
@@ -93,11 +94,13 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
...
@@ -93,11 +94,13 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
int
size
=
ops
->
size
();
int
size
=
ops
->
size
();
int
left
=
0
;
int
left
=
0
;
while
(
left
<
size
&&
ops
->
at
(
left
)
->
Type
()
!=
framework
::
kFeedOpType
&&
while
(
left
<
size
&&
ops
->
at
(
left
)
->
Type
()
!=
framework
::
kFeedOpType
&&
ops
->
at
(
left
)
->
Type
()
!=
"read"
&&
ops
->
at
(
left
)
->
Type
()
!=
framework
::
kFetchOpType
)
{
ops
->
at
(
left
)
->
Type
()
!=
framework
::
kFetchOpType
)
{
++
left
;
++
left
;
}
}
while
(
left
<
size
&&
ops
->
at
(
left
)
->
Type
()
==
framework
::
kFeedOpType
)
{
while
(
left
<
size
&&
(
ops
->
at
(
left
)
->
Type
()
==
framework
::
kFeedOpType
||
ops
->
at
(
left
)
->
Type
()
==
"read"
))
{
for
(
auto
&
var_name_item
:
ops
->
at
(
left
)
->
Outputs
())
{
for
(
auto
&
var_name_item
:
ops
->
at
(
left
)
->
Outputs
())
{
for
(
auto
&
var_name
:
var_name_item
.
second
)
{
for
(
auto
&
var_name
:
var_name_item
.
second
)
{
NgraphEngine
::
feed_vars
.
emplace_back
(
var_name
);
NgraphEngine
::
feed_vars
.
emplace_back
(
var_name
);
...
@@ -270,6 +273,7 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
...
@@ -270,6 +273,7 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
for
(
auto
op_desc
:
ops_desc
)
{
for
(
auto
op_desc
:
ops_desc
)
{
if
(
op_desc
->
Type
().
find
(
"_grad"
)
!=
std
::
string
::
npos
)
{
if
(
op_desc
->
Type
().
find
(
"_grad"
)
!=
std
::
string
::
npos
)
{
is_training
=
true
;
this
->
is_test_
=
false
;
this
->
is_test_
=
false
;
break
;
break
;
}
}
...
@@ -590,7 +594,7 @@ void NgraphEngine::Run(const framework::Scope& scope,
...
@@ -590,7 +594,7 @@ void NgraphEngine::Run(const framework::Scope& scope,
}
}
bool
is_persistable
=
bool
is_persistable
=
(
p_persistables
->
find
(
vi
)
!=
p_persistables
->
end
())
?
true
:
false
;
(
p_persistables
->
find
(
vi
)
!=
p_persistables
->
end
())
?
true
:
false
;
if
(
is_test
&&
is_persistable
)
{
if
(
!
is_training
&&
is_test
&&
is_persistable
)
{
ti
->
set_stale
(
false
);
ti
->
set_stale
(
false
);
}
}
(
*
p_t_in
).
emplace_back
(
ti
);
(
*
p_t_in
).
emplace_back
(
ti
);
...
...
paddle/fluid/operators/ngraph/ngraph_engine.h
浏览文件 @
a0af374f
...
@@ -57,6 +57,7 @@ class NgraphEngine {
...
@@ -57,6 +57,7 @@ class NgraphEngine {
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
;
void
Run
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
;
static
bool
is_training
;
static
const
framework
::
BlockDesc
*
p_bdesc
;
static
const
framework
::
BlockDesc
*
p_bdesc
;
static
std
::
vector
<
std
::
string
>
feed_vars
,
fetch_vars
;
static
std
::
vector
<
std
::
string
>
feed_vars
,
fetch_vars
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录