Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
298ee7d2
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
298ee7d2
编写于
1月 09, 2020
作者:
B
baojun
提交者:
Tao Luo
1月 10, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Improve ngraph file line coverage (#22155)
上级
d0f0a252
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
20 addition
and
74 deletion
+20
-74
paddle/fluid/operators/ngraph/ngraph_engine.cc
paddle/fluid/operators/ngraph/ngraph_engine.cc
+0
-30
paddle/fluid/operators/ngraph/ops/cast_op.h
paddle/fluid/operators/ngraph/ops/cast_op.h
+0
-15
paddle/fluid/operators/ngraph/ops/elementwise_node.h
paddle/fluid/operators/ngraph/ops/elementwise_node.h
+1
-3
paddle/fluid/operators/ngraph/ops/reshape_op.h
paddle/fluid/operators/ngraph/ops/reshape_op.h
+8
-12
paddle/fluid/operators/ngraph/ops/sum_op.h
paddle/fluid/operators/ngraph/ops/sum_op.h
+9
-8
python/paddle/fluid/tests/unittests/ngraph/test_compare_ngraph_op.py
...le/fluid/tests/unittests/ngraph/test_compare_ngraph_op.py
+1
-1
python/paddle/fluid/tests/unittests/ngraph/test_logical_ngraph_op.py
...le/fluid/tests/unittests/ngraph/test_logical_ngraph_op.py
+1
-5
未找到文件。
paddle/fluid/operators/ngraph/ngraph_engine.cc
浏览文件 @
298ee7d2
...
...
@@ -177,36 +177,6 @@ std::string SerializedBlock(const framework::BlockDesc& bdesc) {
return
block_desc
.
Proto
()
->
SerializeAsString
();
}
std
::
string
GenerateEngineKey
(
const
framework
::
BlockDesc
&
bdesc
)
{
framework
::
proto
::
BlockDesc
block_proto
;
framework
::
BlockDesc
block_desc
(
nullptr
,
&
block_proto
);
block_desc
.
Proto
()
->
set_parent_idx
(
-
1
);
block_desc
.
Proto
()
->
set_idx
(
0
);
for
(
auto
&
op_desc
:
bdesc
.
AllOps
())
{
auto
*
op
=
block_desc
.
AppendOp
();
*
op
->
Proto
()
=
*
op_desc
->
Proto
();
}
auto
engine_key
=
std
::
to_string
(
std
::
hash
<
std
::
string
>
()(
block_desc
.
Proto
()
->
SerializeAsString
()));
return
engine_key
;
}
std
::
string
GenerateEngineKey
(
const
std
::
vector
<
std
::
string
>&
engine_inputs
,
const
std
::
vector
<
std
::
string
>&
engine_outputs
,
int
size
)
{
std
::
string
engine_hash_key
=
""
;
for
(
auto
name
:
engine_inputs
)
{
engine_hash_key
+=
name
;
}
for
(
auto
name
:
engine_outputs
)
{
engine_hash_key
+=
name
;
}
engine_hash_key
+=
std
::
to_string
(
size
);
auto
engine_key
=
std
::
to_string
(
std
::
hash
<
std
::
string
>
()(
engine_hash_key
));
return
engine_key
;
}
void
NgraphEngine
::
FuseNgraphOps
(
const
framework
::
BlockDesc
&
block_desc
,
std
::
vector
<
std
::
unique_ptr
<
framework
::
OperatorBase
>>*
ops
)
{
...
...
paddle/fluid/operators/ngraph/ops/cast_op.h
浏览文件 @
298ee7d2
...
...
@@ -40,23 +40,8 @@ static void BuildCastNode(
auto
out
=
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
input
,
ng_dtype
);
paddle
::
platform
::
SetOutputNode
(
op
,
"Out"
,
out
,
ngb_node_map
);
}
static
void
BuildCastGradNode
(
const
std
::
shared_ptr
<
framework
::
OperatorBase
>&
op
,
std
::
shared_ptr
<
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
ngraph
::
Node
>>>
ngb_node_map
)
{
auto
input
=
platform
::
GetInputNode
(
op
,
"Out@GRAD"
,
ngb_node_map
);
auto
op_attrs
=
framework
::
AttrReader
(
op
->
Attrs
());
auto
ng_dtype
=
platform
::
GetNgType
(
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
op_attrs
.
Get
<
int
>
(
"out_dtype"
)));
auto
out
=
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
input
,
ng_dtype
);
platform
::
SetOutputNode
(
op
,
"X@GRAD"
,
out
,
ngb_node_map
);
}
}
// namespace ngraphs
}
// namespace operators
}
// namespace paddle
REGISTER_NG_OP
(
cast
,
BuildCastNode
);
REGISTER_NG_OP
(
cast_grad
,
BuildCastGradNode
);
paddle/fluid/operators/ngraph/ops/elementwise_node.h
浏览文件 @
298ee7d2
...
...
@@ -37,9 +37,7 @@ void BuildElementwiseBinaryNode(
std
::
shared_ptr
<
ngraph
::
Node
>&
x
=
nodes
.
at
(
0
);
std
::
shared_ptr
<
ngraph
::
Node
>&
y
=
nodes
.
at
(
1
);
if
(
x
->
get_element_type
()
!=
y
->
get_element_type
())
{
y
=
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
y
,
x
->
get_element_type
());
}
y
=
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
y
,
x
->
get_element_type
());
auto
out
=
std
::
make_shared
<
T
>
(
x
,
y
);
paddle
::
platform
::
SetOutputNode
(
op
,
"Out"
,
out
,
ngb_node_map
);
}
...
...
paddle/fluid/operators/ngraph/ops/reshape_op.h
浏览文件 @
298ee7d2
...
...
@@ -23,6 +23,7 @@ limitations under the License. */
#include "ngraph/ngraph.hpp"
#include "paddle/fluid/operators/ngraph/ops/op_bridge.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/ngraph_helper.h"
namespace
paddle
{
...
...
@@ -60,20 +61,16 @@ static void BuildReshapeNode(
std
::
shared_ptr
<
ngraph
::
Node
>
shape
=
platform
::
GetInputNode
(
op
,
"Shape"
,
ngb_node_map
);
PADDLE_ENFORCE_EQ
(
shape
,
nullptr
,
platform
::
errors
::
Unimplemented
(
"Support for Shape input is not implemented"
));
auto
op_attrs
=
framework
::
AttrReader
(
op
->
Attrs
());
std
::
vector
<
int
>
v_shape
=
op_attrs
.
Get
<
std
::
vector
<
int
>>
(
"shape"
);
auto
out
=
input
;
if
(
shape
!=
nullptr
)
{
ngraph
::
Shape
new_shape
;
for
(
auto
&
it
:
shape
->
get_shape
())
{
new_shape
.
push_back
(
it
);
}
out
=
platform
::
NgReshaper
(
input
,
shape
->
get_shape
());
}
else
{
auto
out_shape
=
calc_output_shape
(
input_shape
,
v_shape
);
out
=
platform
::
NgReshaper
(
input
,
out_shape
);
}
auto
out_shape
=
calc_output_shape
(
input_shape
,
v_shape
);
auto
out
=
platform
::
NgReshaper
(
input
,
out_shape
);
platform
::
SetOutputNode
(
op
,
"Out"
,
out
,
ngb_node_map
);
if
(
is_v2
)
{
ngraph
::
Shape
input_xshape
(
input_shape
.
size
()
+
1
);
...
...
@@ -83,7 +80,6 @@ static void BuildReshapeNode(
input
->
get_element_type
(),
input_xshape
,
std
::
vector
<
std
::
string
>
{});
platform
::
SetOutputNode
(
op
,
"XShape"
,
xshape_node
,
ngb_node_map
);
}
platform
::
SetOutputNode
(
op
,
"Out"
,
out
,
ngb_node_map
);
}
template
<
bool
is_v2
>
...
...
paddle/fluid/operators/ngraph/ops/sum_op.h
浏览文件 @
298ee7d2
...
...
@@ -14,7 +14,9 @@ limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "ngraph/ngraph.hpp"
...
...
@@ -34,19 +36,18 @@ void BuildSumNode(
for
(
auto
&
var_name_item
:
op
->
Inputs
())
{
for
(
auto
&
var_name
:
var_name_item
.
second
)
{
op_inputs
.
push_back
(
var_name
);
if
(
ngb_node_map
->
find
(
var_name
)
==
ngb_node_map
->
end
())
{
PADDLE_THROW
(
"op % input varname %s is not found in var_node_map"
,
op
->
Type
(),
var_name
);
}
PADDLE_ENFORCE_NE
(
ngb_node_map
->
find
(
var_name
),
ngb_node_map
->
end
(),
platform
::
errors
::
NotFound
(
"op %s input varname %s is not found in var_node_map"
,
op
->
Type
(),
var_name
));
}
}
std
::
shared_ptr
<
ngraph
::
Node
>&
sum
=
ngb_node_map
->
at
(
op_inputs
[
0
]);
for
(
size_t
k
=
1
;
k
<
op_inputs
.
size
();
++
k
)
{
std
::
shared_ptr
<
ngraph
::
Node
>&
nodek
=
ngb_node_map
->
at
(
op_inputs
[
k
]);
if
(
nodek
->
get_element_type
()
!=
sum
->
get_element_type
())
{
nodek
=
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
nodek
,
sum
->
get_element_type
());
}
nodek
=
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
nodek
,
sum
->
get_element_type
());
sum
=
sum
+
nodek
;
}
platform
::
SetOutputNode
(
op
,
"Out"
,
sum
,
ngb_node_map
);
...
...
python/paddle/fluid/tests/unittests/ngraph/test_compare_ngraph_op.py
浏览文件 @
298ee7d2
...
...
@@ -17,7 +17,7 @@ from __future__ import print_function
import
unittest
import
sys
sys
.
path
.
append
(
"../"
)
import
test_compare_op
from
test_compare_op
import
*
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/ngraph/test_logical_ngraph_op.py
浏览文件 @
298ee7d2
...
...
@@ -18,11 +18,7 @@ import unittest, sys
sys
.
path
.
append
(
"../"
)
import
numpy
as
np
from
test_logical_op
import
create_test_class
create_test_class
(
'logical_and'
,
lambda
_a
,
_b
:
np
.
logical_and
(
_a
,
_b
))
create_test_class
(
'logical_or'
,
lambda
_a
,
_b
:
np
.
logical_or
(
_a
,
_b
))
create_test_class
(
'logical_not'
,
lambda
_a
:
np
.
logical_not
(
_a
),
False
)
from
test_logical_op
import
*
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录