Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
51a538e0
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
51a538e0
编写于
11月 13, 2018
作者:
B
baojun-nervana
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix style and use enum
test=develop
上级
ea3538d8
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
51 addition
and
47 deletion
+51
-47
paddle/fluid/framework/ngraph_operator.cc
paddle/fluid/framework/ngraph_operator.cc
+42
-38
paddle/fluid/framework/ngraph_operator.h
paddle/fluid/framework/ngraph_operator.h
+9
-9
未找到文件。
paddle/fluid/framework/ngraph_operator.cc
浏览文件 @
51a538e0
...
...
@@ -35,6 +35,13 @@ static std::map<proto::VarType::Type, ngraph::element::Type> pd2ng_type_map = {
{
proto
::
VarType
::
BOOL
,
ngraph
::
element
::
boolean
},
};
typedef
enum
{
/* nGraph support state on ops */
FULL_TRAIN
,
/* Support full ops for train */
PARTIAL_TRAIN
,
/* Support partial ops for train */
FULL_TEST
,
/* Support full list of ops for test */
PARTIAL_TEST
/* Support partial list of ops for test */
}
op_state
;
class
NgraphOperator
{
public:
explicit
NgraphOperator
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
,
...
...
@@ -44,33 +51,29 @@ class NgraphOperator {
const
std
::
unordered_set
<
std
::
string
>&
persist
,
const
std
::
unordered_set
<
std
::
string
>&
fetches
,
const
std
::
unordered_set
<
std
::
string
>&
post_op_inputs
,
int
is_test_or_train
)
:
scope
(
scope
),
place
(
place
),
fused_ops
(
ops
),
var_type_map
(
var_type_map
),
persistables
(
persist
),
fetches
(
fetches
),
post_op_inputs
(
post_op_inputs
),
is_test_or_train
(
is_test_or_train
)
{}
op_state
ng_op_state
)
:
scope
_
(
scope
),
place
_
(
place
),
fused_ops
_
(
ops
),
var_type_map
_
(
var_type_map
),
persistables
_
(
persist
),
fetches
_
(
fetches
),
post_op_inputs
_
(
post_op_inputs
),
ng_op_state_
(
ng_op_state
)
{}
void
Run
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
;
private:
static
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Function
>>
func_cache
;
const
Scope
&
scope
;
const
platform
::
Place
&
place
;
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>
fused_ops
;
std
::
unordered_map
<
std
::
string
,
ngraph
::
element
::
Type
>
var_type_map
;
std
::
unordered_set
<
std
::
string
>
persistables
;
std
::
unordered_set
<
std
::
string
>
fetches
;
std
::
unordered_set
<
std
::
string
>
post_op_inputs
;
// 0 = default; 1 = (is_test && not is_complete)
// 2 = (is_test && is_complete)
// 3 = (is_training && not is_complete)
// 4 = (is_training && is_complete)
int
is_test_or_train
;
const
Scope
&
scope_
;
const
platform
::
Place
&
place_
;
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>
fused_ops_
;
std
::
unordered_map
<
std
::
string
,
ngraph
::
element
::
Type
>
var_type_map_
;
std
::
unordered_set
<
std
::
string
>
persistables_
;
std
::
unordered_set
<
std
::
string
>
fetches_
;
std
::
unordered_set
<
std
::
string
>
post_op_inputs_
;
op_state
ng_op_state_
;
};
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
>>
...
...
@@ -131,19 +134,19 @@ FusedOperator::FusedOperator(
const
ProgramDesc
&
prog
,
size_t
block_id
,
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
start
,
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
end
,
const
std
::
string
&
type
=
"fused_op"
,
const
VariableNameMap
&
inputs
=
{}
,
const
VariableNameMap
&
outputs
=
{},
const
AttributeMap
&
attrs
=
{}
)
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
),
pdesc
(
prog
),
block
(
block_id
)
{
for
(
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
it
=
start
;
it
!=
end
;
++
it
)
{
fused_ops
.
push_back
(
std
::
move
(
*
it
));
fused_ops
_
.
push_back
(
std
::
move
(
*
it
));
}
for
(
std
::
vector
<
std
::
unique_ptr
<
OperatorBase
>>::
iterator
it
=
end
;
(
*
it
)
->
Type
()
!=
kFetchOpType
;
++
it
)
{
for
(
auto
&
var_name_item
:
(
*
it
)
->
Inputs
())
{
for
(
auto
&
var_name
:
var_name_item
.
second
)
{
post_op_inputs
.
insert
(
var_name
);
post_op_inputs
_
.
insert
(
var_name
);
}
}
}
...
...
@@ -152,11 +155,11 @@ FusedOperator::FusedOperator(
is_complete
=
true
;
}
p
rocess
();
P
rocess
();
}
void
FusedOperator
::
p
rocess
()
{
auto
&
bdesc
=
pdesc
.
Block
(
block
);
void
FusedOperator
::
P
rocess
()
{
auto
&
bdesc
=
pdesc
_
.
Block
(
block_
);
for
(
auto
&
var
:
bdesc
.
AllVars
())
{
if
(
!
(
var
->
GetType
()
==
proto
::
VarType
::
SELECTED_ROWS
||
var
->
GetType
()
==
proto
::
VarType
::
LOD_TENSOR
||
...
...
@@ -175,39 +178,40 @@ void FusedOperator::process() {
PADDLE_THROW
(
"Data type of var %s not found in pd2ng_type_map"
,
var_name
);
}
var_type_map
[
var_name
]
=
pd2ng_type_map
[
pd_type
];
var_type_map
_
[
var_name
]
=
pd2ng_type_map
[
pd_type
];
}
if
(
var
->
Persistable
())
{
persistables
.
insert
(
var
->
Name
());
persistables
_
.
insert
(
var
->
Name
());
}
}
for
(
auto
*
op
:
bdesc
.
AllOps
())
{
if
(
op
->
Type
()
==
kFetchOpType
)
{
std
::
string
fetch_target_name
=
op
->
Input
(
"X"
)[
0
];
fetches
.
insert
(
fetch_target_name
);
fetches
_
.
insert
(
fetch_target_name
);
}
}
}
void
FusedOperator
::
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
{
int
is_test_or_train
=
1
;
auto
&
bdesc
=
pdesc
.
Block
(
block
);
op_state
ng_op_state
=
PARTIAL_TEST
;
auto
&
bdesc
=
pdesc
_
.
Block
(
block_
);
for
(
auto
*
op
:
bdesc
.
AllOps
())
{
if
(
op
->
Type
().
find
(
"_grad"
)
!=
std
::
string
::
npos
)
{
is_test_or_train
=
3
;
ng_op_state
=
PARTIAL_TRAIN
;
break
;
}
}
if
(
is_
complete
)
{
is_test_or_train
=
is_test_or_train
==
1
?
2
:
4
;
if
(
is_
full
)
{
ng_op_state
=
ng_op_state
==
PARTIAL_TEST
?
FULL_TEST
:
FULL_TRAIN
;
}
NgraphOperator
ngraph_op
(
scope
,
place
,
fused_ops
,
var_type_map
,
persistables
,
fetches
,
post_op_inputs
,
is_test_or_train
);
NgraphOperator
ngraph_op
(
scope
,
place
,
fused_ops_
,
var_type_map_
,
persistables_
,
fetches_
,
post_op_inputs_
,
ng_op_state
);
ngraph_op
.
Run
(
scope
,
place
);
}
...
...
paddle/fluid/framework/ngraph_operator.h
浏览文件 @
51a538e0
...
...
@@ -56,16 +56,16 @@ class FusedOperator : public OperatorBase {
void
RunImpl
(
const
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
final
;
private:
const
ProgramDesc
pdesc
;
size_t
block
;
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>
fused_ops
;
std
::
unordered_map
<
std
::
string
,
ngraph
::
element
::
Type
>
var_type_map
;
std
::
unordered_set
<
std
::
string
>
persistables
;
std
::
unordered_set
<
std
::
string
>
fetches
;
std
::
unordered_set
<
std
::
string
>
post_op_inputs
;
bool
is_
complete
=
false
;
const
ProgramDesc
pdesc
_
;
size_t
block
_
;
std
::
vector
<
std
::
shared_ptr
<
OperatorBase
>>
fused_ops
_
;
std
::
unordered_map
<
std
::
string
,
ngraph
::
element
::
Type
>
var_type_map
_
;
std
::
unordered_set
<
std
::
string
>
persistables
_
;
std
::
unordered_set
<
std
::
string
>
fetches
_
;
std
::
unordered_set
<
std
::
string
>
post_op_inputs
_
;
bool
is_
full_
=
false
;
void
p
rocess
();
void
P
rocess
();
};
}
// namespace framework
}
// namespace paddle
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录