Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
6f99bf84
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
6f99bf84
编写于
11月 01, 2016
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
11月 01, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Store ksize attribute for graph transfer to SOC
Change: 137855838
上级
b00e9404
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
51 addition
and
5 deletion
+51
-5
tensorflow/core/kernels/hexagon/graph_transferer.cc
tensorflow/core/kernels/hexagon/graph_transferer.cc
+8
-1
tensorflow/core/kernels/hexagon/graph_transferer_test.cc
tensorflow/core/kernels/hexagon/graph_transferer_test.cc
+43
-4
未找到文件。
tensorflow/core/kernels/hexagon/graph_transferer.cc
浏览文件 @
6f99bf84
...
...
@@ -31,6 +31,7 @@ static constexpr const char* const CONST_SHAPE_PREFIX = "const_shape_";
static
constexpr
const
char
*
const
PADDING_PREFIX
=
"NN_PAD_"
;
static
constexpr
const
char
*
const
PADDING_ATTR_NAME
=
"padding"
;
static
constexpr
const
char
*
const
STRIDES_ATTR_NAME
=
"strides"
;
static
constexpr
const
char
*
const
KSIZE_ATTR_NAME
=
"ksize"
;
static
constexpr
const
char
*
const
PADDING_VALID_STR
=
"VALID"
;
static
constexpr
const
char
*
const
PADDING_SAME_STR
=
"SAME"
;
...
...
@@ -192,7 +193,13 @@ void GraphTransferer::RegisterNodeWithPaddingAndStrides(
std
::
vector
<
int32
>
strides
;
context
->
GetAttr
(
STRIDES_ATTR_NAME
,
&
strides
);
const
int
stride_id
=
RegisterConstantShape
(
strides
);
std
::
vector
<
int
>
extra_inputs
{
stride_id
,
0
};
std
::
vector
<
int
>
extra_inputs
{
stride_id
};
if
(
node
.
def
().
attr
().
count
(
KSIZE_ATTR_NAME
)
>
0
)
{
std
::
vector
<
int32
>
kernel_sizes
;
context
->
GetAttr
(
KSIZE_ATTR_NAME
,
&
kernel_sizes
);
const
int
ksize_id
=
RegisterConstantShape
(
kernel_sizes
);
extra_inputs
.
push_back
(
ksize_id
);
}
AppendNodeParams
(
node
.
name
(),
id
,
node
.
type_string
(),
padding
,
node
.
num_inputs
(),
extra_inputs
,
node
.
num_outputs
());
}
...
...
tensorflow/core/kernels/hexagon/graph_transferer_test.cc
浏览文件 @
6f99bf84
...
...
@@ -58,14 +58,33 @@ static GraphDef CreateConvGraphDef() {
test
::
FillIota
<
float
>
(
&
input_data
,
1.0
f
);
ops
::
Output
input
=
ops
::
Const
(
root
.
WithOpName
(
"input"
),
ops
::
Input
::
Initializer
(
input_data
));
const
int
stride
=
1
;
Tensor
filter_data
(
DT_FLOAT
,
TensorShape
({
1
,
1
,
1
,
1
}));
test
::
FillIota
<
float
>
(
&
filter_data
,
1.0
f
);
ops
::
Output
filter
=
ops
::
Const
(
root
.
WithOpName
(
"filter"
),
ops
::
Input
::
Initializer
(
filter_data
));
const
std
::
vector
<
int
>
strides
{
1
,
1
,
1
,
1
};
ops
::
Output
conv
=
ops
::
Conv2D
(
root
.
WithOpName
(
"conv"
),
input
,
filter
,
strides
,
"SAME"
);
GraphDef
def
;
TF_CHECK_OK
(
root
.
ToGraphDef
(
&
def
));
return
def
;
}
static
GraphDef
CreatePoolGraphDef
()
{
Scope
root
=
Scope
::
NewRootScope
();
Tensor
input_data
(
DT_FLOAT
,
TensorShape
({
1
,
1
,
1
,
1
}));
test
::
FillIota
<
float
>
(
&
input_data
,
1.0
f
);
ops
::
Output
input
=
ops
::
Const
(
root
.
WithOpName
(
"input"
),
ops
::
Input
::
Initializer
(
input_data
));
Tensor
filter_data
(
DT_FLOAT
,
TensorShape
({
1
,
1
,
1
,
1
}));
test
::
FillIota
<
float
>
(
&
filter_data
,
1.0
f
);
ops
::
Output
filter
=
ops
::
Const
(
root
.
WithOpName
(
"filter"
),
ops
::
Input
::
Initializer
(
filter_data
));
const
std
::
vector
<
int
>
ksize
{
1
,
1
,
1
,
1
};
const
std
::
vector
<
int
>
padding
{
0
,
0
,
0
,
0
};
ops
::
Output
conv
=
ops
::
Conv2D
(
root
.
WithOpName
(
"conv"
),
input
,
filter
,
{
1
,
stride
,
stride
,
1
},
"SAME"
);
const
std
::
vector
<
int
>
strides
{
1
,
1
,
1
,
1
};
ops
::
Output
max_pool
=
ops
::
MaxPool
(
root
.
WithOpName
(
"maxpool"
),
input
,
ksize
,
strides
,
"SAME"
);
GraphDef
def
;
TF_CHECK_OK
(
root
.
ToGraphDef
(
&
def
));
return
def
;
...
...
@@ -139,9 +158,29 @@ TEST_F(GraphTransfererTest, LoadConvGraph) {
const
int
id
=
params_conv
->
id
;
EXPECT_TRUE
(
id
>
0
&&
id
<=
(
const_node_count
+
op_node_count
));
EXPECT_EQ
(
"Conv2D"
,
params_conv
->
type
);
EXPECT_EQ
(
4
,
params_conv
->
inputs_size
);
EXPECT_EQ
(
3
,
params_conv
->
inputs_size
);
EXPECT_EQ
(
1
,
params_conv
->
outputs_size
);
EXPECT_EQ
(
"NN_PAD_SAME"
,
params_conv
->
padding
);
}
TEST_F
(
GraphTransfererTest
,
LoadMaxPoolGraph
)
{
GraphDef
def
=
CreatePoolGraphDef
();
_session
->
Create
(
def
);
GraphTransferer
gt
;
gt
.
LoadGraphFromProto
(
def
);
const
int
const_node_count
=
gt
.
GetConstNodeParams
().
size
();
ASSERT_EQ
(
3
,
const_node_count
);
const
int
op_node_count
=
gt
.
GetOpNodeParams
().
size
();
ASSERT_EQ
(
1
,
op_node_count
);
const
GraphTransferer
::
NodeTransferParams
*
params_max_pool
=
FindOpNodeParams
(
gt
,
"maxpool"
);
ASSERT_TRUE
(
params_max_pool
!=
nullptr
);
const
int
id
=
params_max_pool
->
id
;
EXPECT_TRUE
(
id
>
0
&&
id
<=
(
const_node_count
+
op_node_count
));
EXPECT_EQ
(
"MaxPool"
,
params_max_pool
->
type
);
EXPECT_EQ
(
3
,
params_max_pool
->
inputs_size
);
EXPECT_EQ
(
1
,
params_max_pool
->
outputs_size
);
EXPECT_EQ
(
"NN_PAD_SAME"
,
params_max_pool
->
padding
);
}
}
// namespace tensorflow
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录