Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
22ac2133
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
22ac2133
编写于
12月 07, 2018
作者:
B
baojun-nervana
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rename class
test=develop
上级
bfde5e10
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
35 addition
and
36 deletion
+35
-36
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+4
-4
paddle/fluid/framework/ngraph_operator.cc
paddle/fluid/framework/ngraph_operator.cc
+28
-29
paddle/fluid/framework/ngraph_operator.h
paddle/fluid/framework/ngraph_operator.h
+3
-3
未找到文件。
paddle/fluid/framework/executor.cc
浏览文件 @
22ac2133
...
...
@@ -91,11 +91,11 @@ static void DeleteUnusedTensors(const Scope& scope, const OperatorBase* op,
static
void
EnableFusedOp
(
ExecutorPrepareContext
*
ctx
)
{
#ifdef PADDLE_WITH_NGRAPH
VLOG
(
3
)
<<
"use_ngraph=True"
;
auto
intervals
=
FusedOperator
::
Fused
OpIntervals
(
&
ctx
->
ops_
);
auto
intervals
=
NgraphOperator
::
Ngraph
OpIntervals
(
&
ctx
->
ops_
);
for
(
auto
&
interval
:
intervals
)
{
auto
*
fused_op
=
new
FusedOperator
(
ctx
->
prog_
,
ctx
->
block_id_
,
interval
.
at
(
0
),
interval
.
at
(
1
));
*
interval
[
0
]
=
std
::
unique_ptr
<
OperatorBase
>
(
fused
_op
);
auto
*
ng_op
=
new
NgraphOperator
(
ctx
->
prog_
,
ctx
->
block_id_
,
interval
.
at
(
0
)
,
interval
.
at
(
1
));
*
interval
[
0
]
=
std
::
unique_ptr
<
OperatorBase
>
(
ng
_op
);
}
for
(
auto
it
=
intervals
.
rbegin
();
it
!=
intervals
.
rend
();
++
it
)
{
ctx
->
ops_
.
erase
(
it
->
at
(
0
)
+
1
,
it
->
at
(
1
));
...
...
paddle/fluid/framework/ngraph_operator.cc
浏览文件 @
22ac2133
...
...
@@ -57,16 +57,16 @@ typedef enum { /* nGraph support state on ops */
}
op_state
;
// perform graph build through bridge and execute computation
class
Ngraph
Operator
{
class
Ngraph
Engine
{
public:
explicit
Ngraph
Operator
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>&
ops
,
const
std
::
unordered_map
<
std
::
string
,
ngraph
::
element
::
Type
>&
var_type_map
,
const
std
::
unordered_set
<
std
::
string
>&
persist
,
const
std
::
unordered_set
<
std
::
string
>&
fetches
,
const
std
::
unordered_set
<
std
::
string
>&
post_op_inputs
,
op_state
ng_op_state
)
explicit
Ngraph
Engine
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
,
const
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>&
ops
,
const
std
::
unordered_map
<
std
::
string
,
ngraph
::
element
::
Type
>&
var_type_map
,
const
std
::
unordered_set
<
std
::
string
>&
persist
,
const
std
::
unordered_set
<
std
::
string
>&
fetches
,
const
std
::
unordered_set
<
std
::
string
>&
post_op_inputs
,
op_state
ng_op_state
)
:
scope_
(
scope
),
place_
(
place
),
fused_ops_
(
ops
),
...
...
@@ -131,7 +131,7 @@ class NgraphOperator {
};
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
>>
FusedOperator
::
Fused
OpIntervals
(
NgraphOperator
::
Ngraph
OpIntervals
(
std
::
vector
<
std
::
unique_ptr
<
paddle
::
framework
::
OperatorBase
>>*
ops
)
{
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
>>
intervals
;
...
...
@@ -184,7 +184,7 @@ FusedOperator::FusedOpIntervals(
return
intervals
;
}
FusedOperator
::
Fused
Operator
(
NgraphOperator
::
Ngraph
Operator
(
const
ProgramDesc
&
prog
,
size_t
block_id
,
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
start
,
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
end
,
...
...
@@ -214,7 +214,7 @@ FusedOperator::FusedOperator(
Process
();
}
void
Fused
Operator
::
Process
()
{
void
Ngraph
Operator
::
Process
()
{
auto
&
bdesc
=
pdesc_
.
Block
(
block_
);
for
(
auto
&
var
:
bdesc
.
AllVars
())
{
if
(
!
(
var
->
GetType
()
==
proto
::
VarType
::
SELECTED_ROWS
||
...
...
@@ -250,8 +250,8 @@ void FusedOperator::Process() {
}
}
void
Fused
Operator
::
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
void
Ngraph
Operator
::
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
op_state
ng_op_state
=
PARTIAL_TEST
;
auto
&
bdesc
=
pdesc_
.
Block
(
block_
);
for
(
auto
*
op
:
bdesc
.
AllOps
())
{
...
...
@@ -265,19 +265,19 @@ void FusedOperator::RunImpl(const Scope& scope,
ng_op_state
=
ng_op_state
==
PARTIAL_TEST
?
FULL_TEST
:
FULL_TRAIN
;
}
Ngraph
Operator
ngraph_op
(
scope
,
place
,
fused_ops_
,
var_type_map_
,
persistables_
,
fetches_
,
post_op_inputs_
,
ng_op_state
);
ngraph_
op
.
Run
(
scope
,
place
);
Ngraph
Engine
ngraph_engine
(
scope
,
place
,
fused_ops_
,
var_type_map_
,
persistables_
,
fetches_
,
post_op_inputs_
,
ng_op_state
);
ngraph_
engine
.
Run
(
scope
,
place
);
}
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Function
>>
Ngraph
Operator
::
func_cache_
=
{};
Ngraph
Engine
::
func_cache_
=
{};
std
::
shared_ptr
<
ngraph
::
runtime
::
Backend
>
Ngraph
Operator
::
backend_
=
std
::
shared_ptr
<
ngraph
::
runtime
::
Backend
>
Ngraph
Engine
::
backend_
=
ngraph
::
runtime
::
Backend
::
create
(
"CPU"
);
void
Ngraph
Operator
::
GetNgInputShape
(
std
::
shared_ptr
<
OperatorBase
>
op
)
{
void
Ngraph
Engine
::
GetNgInputShape
(
std
::
shared_ptr
<
OperatorBase
>
op
)
{
op
->
RuntimeInferShape
(
scope_
,
place_
);
for
(
auto
&
var_name_item
:
op
->
Inputs
())
{
for
(
auto
&
var_name
:
var_name_item
.
second
)
{
...
...
@@ -300,7 +300,7 @@ void NgraphOperator::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
}
}
void
Ngraph
Operator
::
BuildNgNodes
()
{
void
Ngraph
Engine
::
BuildNgNodes
()
{
for
(
auto
&
var_name
:
var_out_
)
{
if
(
var_node_map_
->
find
(
var_name
)
==
var_node_map_
->
end
())
{
auto
*
var
=
scope_
.
FindVar
(
var_name
);
...
...
@@ -322,7 +322,7 @@ void NgraphOperator::BuildNgNodes() {
}
}
void
Ngraph
Operator
::
BuildNgIO
()
{
void
Ngraph
Engine
::
BuildNgIO
()
{
std
::
unordered_set
<
std
::
string
>
inputs
;
std
::
unordered_set
<
std
::
string
>
outputs
;
...
...
@@ -394,7 +394,7 @@ void NgraphOperator::BuildNgIO() {
}
}
void
Ngraph
Operator
::
BuildNgFunction
()
{
void
Ngraph
Engine
::
BuildNgFunction
()
{
BuildNgNodes
();
ngraph_function_
=
nullptr
;
ngraph
::
NodeVector
func_outputs
;
...
...
@@ -415,7 +415,7 @@ void NgraphOperator::BuildNgFunction() {
std
::
make_shared
<
ngraph
::
Function
>
(
func_outputs
,
func_inputs
);
}
std
::
shared_ptr
<
std
::
string
>
Ngraph
Operator
::
GetCacheKey
()
{
std
::
shared_ptr
<
std
::
string
>
Ngraph
Engine
::
GetCacheKey
()
{
auto
cache_key
=
std
::
make_shared
<
std
::
string
>
(
""
);
*
cache_key
+=
std
::
to_string
(
fused_ops_
.
size
());
for
(
auto
&
op
:
fused_ops_
)
{
...
...
@@ -443,7 +443,7 @@ std::shared_ptr<std::string> NgraphOperator::GetCacheKey() {
return
cache_key
;
}
void
Ngraph
Operator
::
GetNgFunction
()
{
void
Ngraph
Engine
::
GetNgFunction
()
{
bool
cache_on
=
true
;
if
(
cache_on
)
{
std
::
string
cache_key_val
=
*
GetCacheKey
();
...
...
@@ -458,8 +458,7 @@ void NgraphOperator::GetNgFunction() {
}
}
void
NgraphOperator
::
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
void
NgraphEngine
::
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Tensor
>>
t_in
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Tensor
>>
t_out
;
...
...
@@ -544,6 +543,6 @@ void NgraphOperator::Run(const Scope& scope,
}
backend_
->
call
(
ngraph_function_
,
t_out
,
t_in
);
}
// Ngraph
Operator
::RunImpl
}
// Ngraph
Engine
::RunImpl
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ngraph_operator.h
浏览文件 @
22ac2133
...
...
@@ -32,14 +32,14 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
class
Fused
Operator
:
public
OperatorBase
{
class
Ngraph
Operator
:
public
OperatorBase
{
public:
static
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
>>
Fused
OpIntervals
(
Ngraph
OpIntervals
(
std
::
vector
<
std
::
unique_ptr
<
paddle
::
framework
::
OperatorBase
>>*
ops
);
explicit
Fused
Operator
(
explicit
Ngraph
Operator
(
const
ProgramDesc
&
prog
,
size_t
block_id
,
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
start
,
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
end
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录