Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
7b8e31c5
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
7b8e31c5
编写于
3月 13, 2017
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
3月 13, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Factor out shape inference propagation to RemoteFusedGraphExecuteUtils
Change: 149984977
上级
6121fe5b
变更
7
隐藏空白更改
内联
并排
Showing
7 changed files
with
180 additions
and
84 deletions
+180
-84
tensorflow/core/kernels/hexagon/graph_transferer.cc
tensorflow/core/kernels/hexagon/graph_transferer.cc
+17
-57
tensorflow/core/kernels/hexagon/graph_transferer.h
tensorflow/core/kernels/hexagon/graph_transferer.h
+13
-13
tensorflow/core/kernels/hexagon/graph_transferer_test.cc
tensorflow/core/kernels/hexagon/graph_transferer_test.cc
+1
-1
tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
...flow/core/kernels/hexagon/hexagon_graph_execution_test.cc
+10
-2
tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
+77
-0
tensorflow/core/kernels/remote_fused_graph_execute_utils.h
tensorflow/core/kernels/remote_fused_graph_execute_utils.h
+11
-0
tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc
...low/core/kernels/remote_fused_graph_execute_utils_test.cc
+51
-11
未找到文件。
tensorflow/core/kernels/hexagon/graph_transferer.cc
浏览文件 @
7b8e31c5
...
...
@@ -57,7 +57,7 @@ static string ToString(T val) {
/**
* graph loading functions
* - LoadGraphFromProto
* - LoadGraphFromProtoFile
* - LoadGraphFromProtoFile
* These functions read a graph definition and store parameters
* of node to transfer the graph to SOC.
*/
...
...
@@ -67,60 +67,19 @@ Status GraphTransferer::LoadGraphFromProto(
const
std
::
vector
<
std
::
pair
<
string
,
Tensor
>>&
input_node_info_list
,
const
std
::
vector
<
string
>&
output_node_names
,
const
bool
shape_inference_for_unknown_shape
,
const
OutputTensor
Map
&
output_tensor_map
)
{
const
TensorShape
Map
&
output_tensor_map
)
{
ImportGraphDefOptions
opts
;
Graph
graph
(
OpRegistry
::
Global
());
ShapeRefiner
shape_refiner
(
graph
.
versions
().
producer
(),
graph
.
op_registry
());
VLOG
(
1
)
<<
"Start import graph"
;
Status
status
=
ImportGraphDef
(
opts
,
graph_def
,
&
graph
,
&
shape_refiner
);
if
(
!
status
.
ok
())
{
VLOG
(
1
)
<<
"Failed to import graph "
<<
status
.
ToString
();
return
status
;
}
if
(
shape_inference_for_unknown_shape
&&
!
input_node_info_list
.
empty
())
{
auto
visit
=
[
&
shape_refiner
,
&
input_node_info_list
,
&
status
](
Node
*
node
)
{
if
(
!
status
.
ok
())
{
return
;
}
CHECK_NE
(
node
,
nullptr
);
// If we visit an input node, we use the shape provided and set the
// shape accordingly.
bool
is_input_node
=
false
;
for
(
const
std
::
pair
<
string
,
Tensor
>&
input_node_info
:
input_node_info_list
)
{
if
(
node
->
name
()
==
input_node_info
.
first
)
{
shape_inference
::
InferenceContext
*
context
=
shape_refiner
.
GetContext
(
node
);
shape_inference
::
ShapeHandle
handle
;
status
=
context
->
MakeShapeFromTensorShape
(
input_node_info
.
second
.
shape
(),
&
handle
);
// TODO(b/32704451): Don't just ignore this status!
shape_refiner
.
SetShape
(
node
,
0
,
handle
).
IgnoreError
();
is_input_node
=
true
;
}
if
(
!
status
.
ok
())
{
break
;
}
}
// If not an input node call AddNode() that recomputes the shape.
if
(
!
is_input_node
&&
status
.
ok
())
{
status
=
shape_refiner
.
AddNode
(
node
);
if
(
!
status
.
ok
())
{
VLOG
(
1
)
<<
"Shape inference failed for node: "
<<
node
->
name
();
}
}
};
// Runs a reverse DFS over the entire graph setting the shape for the input
// nodes provided and then recomputing the shape of all the nodes downstream
// from them. The "visit" function is executed for each node after all its
// parents have been visited.
ReverseDFS
(
graph
,
{},
visit
);
if
(
shape_inference_for_unknown_shape
)
{
status
=
RemoteFusedGraphExecuteUtils
::
PropagateShapeInference
(
graph_def
,
input_node_info_list
,
&
graph
,
&
shape_refiner
);
if
(
!
status
.
ok
())
{
VLOG
(
1
)
<<
"Failed to run shape inference: "
<<
status
.
ToString
();
return
status
;
}
}
...
...
@@ -149,6 +108,7 @@ Status GraphTransferer::LoadGraphFromProto(
return
status
;
}
}
SortParams
(
output_node_names
);
for
(
const
std
::
pair
<
string
,
Tensor
>&
input_node_info
:
...
...
@@ -319,7 +279,7 @@ bool GraphTransferer::AreAllInputsCached(const Node& node) const {
Status
GraphTransferer
::
RegisterNode
(
const
IGraphTransferOpsDefinitions
&
ops_definitions
,
const
ShapeRefiner
&
shape_refiner
,
const
OutputTensor
Map
&
output_tensor_map
,
const
ShapeRefiner
&
shape_refiner
,
const
TensorShape
Map
&
output_tensor_map
,
const
Node
&
node
,
const
std
::
vector
<
std
::
pair
<
string
,
Tensor
>>&
input_node_info_list
,
const
std
::
vector
<
string
>&
output_node_names
)
{
...
...
@@ -352,7 +312,7 @@ Status GraphTransferer::RegisterNode(
void
GraphTransferer
::
RegisterConstantNode
(
const
ShapeRefiner
&
shape_refiner
,
const
Node
&
node
,
const
OutputTensor
Map
&
output_tensor_map
)
{
const
TensorShape
Map
&
output_tensor_map
)
{
VLOG
(
1
)
<<
"Register constant node: "
<<
node
.
name
();
CHECK_EQ
(
node_name_to_id_cache_map_
.
count
(
node
.
name
()),
1
);
const
int
id
=
node_name_to_id_cache_map_
[
node
.
name
()];
...
...
@@ -439,7 +399,7 @@ bool GraphTransferer::HasPaddingAndStrides(const Node& node) {
}
bool
GraphTransferer
::
IsNodeFlattenReshape
(
const
Node
&
node
,
const
OutputTensor
Map
&
output_tensor_map
,
const
Node
&
node
,
const
TensorShape
Map
&
output_tensor_map
,
const
ShapeRefiner
&
shape_refiner
)
{
// Check if node is reshape op
if
(
node
.
type_string
()
!=
RESHAPE_NODE_TYPE_STRING
)
{
...
...
@@ -477,7 +437,7 @@ bool GraphTransferer::IsNodeFlattenReshape(
void
GraphTransferer
::
RegisterNodeWithPaddingAndStrides
(
const
IGraphTransferOpsDefinitions
&
ops_definitions
,
const
ShapeRefiner
&
shape_refiner
,
const
OutputTensor
Map
&
output_tensor_map
,
const
ShapeRefiner
&
shape_refiner
,
const
TensorShape
Map
&
output_tensor_map
,
const
Node
&
node
)
{
CHECK_EQ
(
node_name_to_id_cache_map_
.
count
(
node
.
name
()),
1
);
const
int
id
=
node_name_to_id_cache_map_
[
node
.
name
()];
...
...
@@ -512,7 +472,7 @@ void GraphTransferer::RegisterNodeWithPaddingAndStrides(
void
GraphTransferer
::
RegisterInputNode
(
const
IGraphTransferOpsDefinitions
&
ops_definitions
,
const
ShapeRefiner
&
shape_refiner
,
const
OutputTensor
Map
&
output_tensor_map
,
const
ShapeRefiner
&
shape_refiner
,
const
TensorShape
Map
&
output_tensor_map
,
const
Node
&
node
)
{
VLOG
(
1
)
<<
"Register input node: "
<<
node
.
name
();
CHECK_EQ
(
node_name_to_id_cache_map_
.
count
(
node
.
name
()),
1
);
...
...
@@ -530,7 +490,7 @@ void GraphTransferer::RegisterInputNode(
void
GraphTransferer
::
RegisterFlattenNode
(
const
IGraphTransferOpsDefinitions
&
ops_definitions
,
const
ShapeRefiner
&
shape_refiner
,
const
OutputTensor
Map
&
output_tensor_map
,
const
ShapeRefiner
&
shape_refiner
,
const
TensorShape
Map
&
output_tensor_map
,
const
Node
&
node
)
{
VLOG
(
1
)
<<
"Register flatten node: "
<<
node
.
name
();
CHECK_EQ
(
node_name_to_id_cache_map_
.
count
(
node
.
name
()),
1
);
...
...
@@ -547,7 +507,7 @@ void GraphTransferer::RegisterFlattenNode(
void
GraphTransferer
::
RegisterGenericNode
(
const
IGraphTransferOpsDefinitions
&
ops_definitions
,
const
ShapeRefiner
&
shape_refiner
,
const
OutputTensor
Map
&
output_tensor_map
,
const
ShapeRefiner
&
shape_refiner
,
const
TensorShape
Map
&
output_tensor_map
,
const
Node
&
node
)
{
VLOG
(
1
)
<<
"Register generic node: "
<<
node
.
name
();
CHECK_EQ
(
node_name_to_id_cache_map_
.
count
(
node
.
name
()),
1
);
...
...
@@ -569,7 +529,7 @@ Status GraphTransferer::RegisterNodeIfAllInputsAreCached(
const
bool
only_register_const_node
,
const
std
::
vector
<
std
::
pair
<
string
,
Tensor
>>&
input_node_info_list
,
const
std
::
vector
<
string
>&
output_node_names
,
const
OutputTensor
Map
&
output_tensor_map
)
{
const
TensorShape
Map
&
output_tensor_map
)
{
if
(
only_register_const_node
&&
!
node
.
IsConstant
())
{
return
Status
();
}
...
...
@@ -627,7 +587,7 @@ void GraphTransferer::AppendNodeInputParams(
}
void
GraphTransferer
::
AppendNodeOutputParams
(
const
ShapeRefiner
&
shape_refiner
,
const
OutputTensor
Map
&
output_tensor_map
,
const
ShapeRefiner
&
shape_refiner
,
const
TensorShape
Map
&
output_tensor_map
,
const
int
id
,
const
Node
&
node
)
{
VLOG
(
1
)
<<
"Append output params: "
<<
node
.
name
()
<<
", "
<<
node
.
num_outputs
();
...
...
@@ -670,7 +630,7 @@ void GraphTransferer::AppendNodeOutputParams(
}
void
GraphTransferer
::
AppendNodeParamsWithIoParams
(
const
ShapeRefiner
&
shape_refiner
,
const
OutputTensor
Map
&
output_tensor_map
,
const
ShapeRefiner
&
shape_refiner
,
const
TensorShape
Map
&
output_tensor_map
,
const
Node
&
node
,
const
string
&
name
,
const
int
id
,
const
string
&
type
,
const
int
type_id
,
const
int
padding
,
const
int
inputs_size
,
const
std
::
vector
<
int
>&
extra_inputs
,
const
int
outputs_size
,
...
...
@@ -757,7 +717,7 @@ GraphTransferer::ToTensorShapeArray(const TensorShape& shape) {
}
/* static */
void
GraphTransferer
::
CheckShape
(
const
OutputTensor
Map
&
output_tensor_map
,
const
string
&
node_name
,
const
TensorShape
Map
&
output_tensor_map
,
const
string
&
node_name
,
const
std
::
array
<
int64
,
SHAPE_ARRAY_SIZE
>&
expected
)
{
if
(
output_tensor_map
.
empty
())
{
// As output_tensor_map is empty, skip checking tensor shape.
...
...
tensorflow/core/kernels/hexagon/graph_transferer.h
浏览文件 @
7b8e31c5
...
...
@@ -45,7 +45,7 @@ class GraphTransferer {
static
constexpr
int
MAX_SUPPORTED_RANK
=
4
;
// TODO(satok): Remove. Use proto definition instead.
static
constexpr
int
SHAPE_ARRAY_SIZE
=
MAX_SUPPORTED_RANK
;
using
OutputTensor
Map
=
RemoteFusedGraphExecuteUtils
::
TensorShapeMap
;
using
TensorShape
Map
=
RemoteFusedGraphExecuteUtils
::
TensorShapeMap
;
GraphTransferer
()
=
default
;
...
...
@@ -58,7 +58,7 @@ class GraphTransferer {
const
std
::
vector
<
std
::
pair
<
string
,
Tensor
>>&
input_node_info_list
,
const
std
::
vector
<
string
>&
output_node_names
,
const
bool
shape_inference_for_unkown_shape
,
const
OutputTensor
Map
&
output_tensor_map
);
const
TensorShape
Map
&
output_tensor_map
);
// Load graph structure into GraphTransferer from protobuf file
// TODO(satok): Pass a pair of TensorShape and DataType instead of
...
...
@@ -107,12 +107,12 @@ class GraphTransferer {
Status
RegisterNode
(
const
IGraphTransferOpsDefinitions
&
ops_definitions
,
const
ShapeRefiner
&
shape_refiner
,
const
OutputTensor
Map
&
output_tensor_map
,
const
Node
&
node
,
const
TensorShape
Map
&
output_tensor_map
,
const
Node
&
node
,
const
std
::
vector
<
std
::
pair
<
string
,
Tensor
>>&
input_node_info_list
,
const
std
::
vector
<
string
>&
output_node_names
);
void
RegisterConstantNode
(
const
ShapeRefiner
&
shape_refiner
,
const
Node
&
node
,
const
OutputTensor
Map
&
output_tensor_map
);
const
TensorShape
Map
&
output_tensor_map
);
int
RegisterConstantShape
(
const
std
::
vector
<
int
>&
shape
);
...
...
@@ -122,27 +122,27 @@ class GraphTransferer {
// TODO(satok): Remove this method once generic reshape op is implemented in
// SOC
bool
IsNodeFlattenReshape
(
const
Node
&
node
,
const
OutputTensor
Map
&
output_tensor_map
,
const
TensorShape
Map
&
output_tensor_map
,
const
ShapeRefiner
&
shape_refiner
);
void
RegisterNodeWithPaddingAndStrides
(
const
IGraphTransferOpsDefinitions
&
ops_definitions
,
const
ShapeRefiner
&
shape_refiner
,
const
OutputTensor
Map
&
output_tensor_map
,
const
Node
&
node
);
const
TensorShape
Map
&
output_tensor_map
,
const
Node
&
node
);
void
RegisterInputNode
(
const
IGraphTransferOpsDefinitions
&
ops_definitions
,
const
ShapeRefiner
&
shape_refiner
,
const
OutputTensor
Map
&
output_tensor_map
,
const
TensorShape
Map
&
output_tensor_map
,
const
Node
&
node
);
void
RegisterFlattenNode
(
const
IGraphTransferOpsDefinitions
&
ops_definitions
,
const
ShapeRefiner
&
shape_refiner
,
const
OutputTensor
Map
&
output_tensor_map
,
const
TensorShape
Map
&
output_tensor_map
,
const
Node
&
node
);
void
RegisterGenericNode
(
const
IGraphTransferOpsDefinitions
&
ops_definitions
,
const
ShapeRefiner
&
shape_refiner
,
const
OutputTensor
Map
&
output_tensor_map
,
const
TensorShape
Map
&
output_tensor_map
,
const
Node
&
node
);
Status
RegisterNodeIfAllInputsAreCached
(
...
...
@@ -151,7 +151,7 @@ class GraphTransferer {
const
bool
only_register_const_node
,
const
std
::
vector
<
std
::
pair
<
string
,
Tensor
>>&
input_node_info_list
,
const
std
::
vector
<
string
>&
output_node_names
,
const
OutputTensor
Map
&
output_tensor_map
);
const
TensorShape
Map
&
output_tensor_map
);
void
AppendNodeParams
(
const
string
&
name
,
const
int
id
,
const
string
&
type
,
const
int
type_id
,
const
int
padding
,
...
...
@@ -163,7 +163,7 @@ class GraphTransferer {
const
std
::
vector
<
int
>&
extra_inputs
);
void
AppendNodeOutputParams
(
const
ShapeRefiner
&
shape_refiner
,
const
OutputTensor
Map
&
output_tensor_map
,
const
TensorShape
Map
&
output_tensor_map
,
const
int
id
,
const
Node
&
node
);
static
std
::
array
<
int64
,
SHAPE_ARRAY_SIZE
>
BuildShapeArray
(
...
...
@@ -172,7 +172,7 @@ class GraphTransferer {
void
AppendNodeParamsWithIoParams
(
const
ShapeRefiner
&
shape_refiner
,
const
OutputTensor
Map
&
output_tensor_map
,
const
Node
&
node
,
const
TensorShape
Map
&
output_tensor_map
,
const
Node
&
node
,
const
string
&
name
,
const
int
id
,
const
string
&
type
,
const
int
type_id
,
const
int
padding
,
const
int
inputs_size
,
const
std
::
vector
<
int
>&
extra_inputs
,
const
int
outputs_size
,
...
...
@@ -183,7 +183,7 @@ class GraphTransferer {
static
string
ToPaddingDebugString
(
int
padding
);
static
void
CheckShape
(
const
OutputTensor
Map
&
output_tensor_map
,
static
void
CheckShape
(
const
TensorShape
Map
&
output_tensor_map
,
const
string
&
node_name
,
const
std
::
array
<
int64
,
SHAPE_ARRAY_SIZE
>&
actual
);
...
...
tensorflow/core/kernels/hexagon/graph_transferer_test.cc
浏览文件 @
7b8e31c5
...
...
@@ -49,7 +49,7 @@ class GraphTransfererTest : public ::testing::Test {
static
const
std
::
vector
<
string
>
OP_TYPES
{
"INPUT"
,
"OUTPUT"
,
"Conv2D"
,
"MaxPool"
,
"NoOp"
,
"Add"
,
"Const"
,
"Softmax"
};
const
GraphTransferer
::
OutputTensor
Map
EMPTY_OUTPUT_TENSOR_MAP
;
const
RemoteFusedGraphExecuteUtils
::
TensorShape
Map
EMPTY_OUTPUT_TENSOR_MAP
;
class
TestGraphTransferOpsDefinitions
:
public
IGraphTransferOpsDefinitions
{
public:
...
...
tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
浏览文件 @
7b8e31c5
...
...
@@ -56,6 +56,7 @@ constexpr const char* const FUSED_MODEL_FILENAME =
"tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb"
;
constexpr
const
char
*
const
REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME
=
"remote_fused_graph_execute_node"
;
constexpr
bool
USE_SHAPE_INFERENCE
=
false
;
const
bool
DBG_DUMP_FLOAT_DATA
=
false
;
const
int
WIDTH
=
299
;
...
...
@@ -282,11 +283,18 @@ TEST(GraphTransferer,
RemoteFusedGraphExecuteUtils
::
TensorShapeMap
output_tensor_info
;
GraphTransferer
gt
;
gt
.
EnableStrictCheckMode
(
false
);
profile_utils
::
CpuUtils
::
EnableClockCycleProfiling
(
true
);
ClockCycleProfiler
prof
;
prof
.
Start
();
Status
status
=
gt
.
LoadGraphFromProtoFile
(
*
ops_definitions
,
MODEL_FILENAME
,
inputs
,
output_node_names
,
false
/* is_text_proto */
,
false
/* shape_inference_for_unknown_shape */
,
true
/* dry_run_for_unknown_shape */
,
&
output_tensor_info
);
false
,
// is_text_proto
USE_SHAPE_INFERENCE
,
// shape_inference_for_unknown_shape
!
USE_SHAPE_INFERENCE
,
// dry_run_for_unknown_shape
&
output_tensor_info
);
ASSERT_TRUE
(
status
.
ok
())
<<
status
;
prof
.
Stop
();
prof
.
DumpStatistics
(
"LoadGraphFromProtoFile"
);
std
::
vector
<
float
>
img_floats
;
LoadImage
(
&
img_floats
);
...
...
tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
浏览文件 @
7b8e31c5
...
...
@@ -15,9 +15,12 @@ limitations under the License.
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include <algorithm>
#include <utility>
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
...
...
@@ -222,4 +225,78 @@ RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() {
return
true
;
}
// Propagates caller-supplied input shapes through `graph`.
//
// For every node whose name matches an entry of `input_node_info_list`, the
// shape of output 0 is set from the shape of the paired Tensor. Every other
// node is passed to ShapeRefiner::AddNode() so that its shape is recomputed.
//
// Args:
//   graph_def: not read by this function; kept so the signature stays
//     compatible with existing callers.
//   input_node_info_list: (node name, tensor) pairs; only the tensor's shape
//     is consulted.
//   graph: graph to traverse.
//   shape_refiner: refiner that receives the shapes.
//
// Returns the first non-OK status encountered, or OK. Once a failure has been
// recorded, the remaining nodes are visited but left untouched.
/* static */ Status RemoteFusedGraphExecuteUtils::PropagateShapeInference(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    Graph* graph, ShapeRefiner* shape_refiner) {
  Status status;
  auto visit = [&shape_refiner, &input_node_info_list, &status](Node* node) {
    if (!status.ok()) {
      // A previous node already failed; keep the first error.
      return;
    }
    CHECK_NE(node, nullptr);
    // If we visit an input node, we use the shape provided and set the
    // shape accordingly.
    bool is_input_node = false;
    for (const std::pair<string, Tensor>& input_node_info :
         input_node_info_list) {
      if (node->name() != input_node_info.first) {
        continue;
      }
      shape_inference::InferenceContext* context =
          shape_refiner->GetContext(node);
      shape_inference::ShapeHandle handle;
      status = context->MakeShapeFromTensorShape(
          input_node_info.second.shape(), &handle);
      if (!status.ok()) {
        // Do not hand an unset handle to SetShape().
        break;
      }
      // Propagate the result instead of dropping it: a rejected shape would
      // otherwise go unnoticed by the caller.
      status = shape_refiner->SetShape(node, 0, handle);
      if (!status.ok()) {
        break;
      }
      is_input_node = true;
    }
    // If not an input node call AddNode() that recomputes the shape.
    if (!is_input_node && status.ok()) {
      status = shape_refiner->AddNode(node);
    }
    if (!status.ok()) {
      VLOG(1) << "Shape inference failed for node: " << node->name();
    }
  };

  // Runs a reverse DFS over the entire graph setting the shape for the input
  // nodes provided and then recomputing the shape of all the nodes downstream
  // from them. The "visit" function is executed for each node after all its
  // parents have been visited.
  ReverseDFS(*graph, {}, visit);

  return status;
}
// Fills `tensor_shape_map` with one (DataType, TensorShape) entry per node of
// `graph`, keyed by node name, using the shapes held by `shape_refiner`.
//
// Nodes without outputs contribute nothing. Because the map is keyed by node
// name only, a name can be recorded once: a node with more than one output,
// or a name already present in `tensor_shape_map`, yields InvalidArgument
// rather than aborting the process.
//
// Returns:
//   OK when every output shape is fully known and was recorded.
//   InvalidArgument when an output has unknown rank, an unknown dimension, or
//   its node name is already in the map. Entries added before the failure
//   stay in `tensor_shape_map`.
/* static */ Status RemoteFusedGraphExecuteUtils::BuildTensorShapeMapFromGraph(
    const Graph& graph, const ShapeRefiner& shape_refiner,
    TensorShapeMap* tensor_shape_map) {
  for (int i = 0; i < graph.num_node_ids(); ++i) {
    const Node* node = graph.FindNodeId(i);
    CHECK_NE(node, nullptr);
    const int num_outputs = node->num_outputs();
    if (num_outputs == 0) {
      // Nothing to record; the context is only needed for nodes with outputs.
      continue;
    }
    const string& node_name = node->name();
    // The context depends on the node only, so fetch it once per node.
    shape_inference::InferenceContext* context = shape_refiner.GetContext(node);
    CHECK_NE(context, nullptr);
    for (int j = 0; j < num_outputs; ++j) {
      const DataType dt = node->output_type(j);
      shape_inference::ShapeHandle shape_handle = context->output(j);
      if (!context->RankKnown(shape_handle)) {
        return errors::InvalidArgument("Graph contains unknown shapes: ",
                                       node_name, ":", j);
      }
      TensorShape ts;
      for (int k = 0; k < context->Rank(shape_handle); ++k) {
        shape_inference::DimensionHandle dh = context->Dim(shape_handle, k);
        if (!context->ValueKnown(dh)) {
          // Report like the unknown-rank case instead of CHECK-failing.
          return errors::InvalidArgument(
              "Graph contains unknown dimension: ", node_name, ":", j,
              " dim ", k);
        }
        ts.AddDim(context->Value(dh));
      }
      if (tensor_shape_map->count(node_name) != 0) {
        // Reached for the second output of a multi-output node, or when the
        // caller passed a map that already holds this name.
        return errors::InvalidArgument(
            "Tensor shape is already registered for node: ", node_name,
            " (output ", j, ")");
      }
      tensor_shape_map->emplace(node_name, std::make_pair(dt, ts));
    }
  }
  return Status::OK();
}
}
// namespace tensorflow
tensorflow/core/kernels/remote_fused_graph_execute_utils.h
浏览文件 @
7b8e31c5
...
...
@@ -20,6 +20,8 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
...
...
@@ -89,6 +91,15 @@ class RemoteFusedGraphExecuteUtils {
const
std
::
vector
<
TensorShape
>&
shapes
,
NodeDef
*
node_def
);
static
Status
PropagateShapeInference
(
const
GraphDef
&
graph_def
,
const
std
::
vector
<
std
::
pair
<
string
,
Tensor
>>&
input_node_info_list
,
Graph
*
graph
,
ShapeRefiner
*
shape_refiner
);
static
Status
BuildTensorShapeMapFromGraph
(
const
Graph
&
graph
,
const
ShapeRefiner
&
shape_refiner
,
TensorShapeMap
*
tensor_shape_map
);
private:
static
ExecutorBuildRegistry
*
GetExecutorBuildRegistry
();
...
...
tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc
浏览文件 @
7b8e31c5
...
...
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
...
...
@@ -100,27 +101,66 @@ TEST(RemoteFusedGraphExecuteUtils, DryRunAddGraphForAllNodes) {
// Setup dryrun arguments
const
std
::
vector
<
std
::
pair
<
string
,
Tensor
>>
inputs
{
input_node_info_a
};
RemoteFusedGraphExecuteUtils
::
TensorShapeMap
output_tensor_info
;
RemoteFusedGraphExecuteUtils
::
TensorShapeMap
tensor_shape_map
;
GraphDef
def
=
RemoteFusedGraphExecuteOpTestUtils
::
BuildAddGraph
(
NAME_A
,
NODE_A_VAL
,
NAME_B
,
NODE_B_VAL
,
NAME_A_PLUS_B
);
// dryrun
const
Status
status
=
RemoteFusedGraphExecuteUtils
::
DryRunInferenceForAllNode
(
def
,
inputs
,
false
/* initialize_by_zero */
,
&
output_tensor_info
);
def
,
inputs
,
false
/* initialize_by_zero */
,
&
tensor_shape_map
);
ASSERT_TRUE
(
status
.
ok
())
<<
status
;
// Assert output node count
ASSERT_EQ
(
3
,
output_tensor_info
.
size
());
ASSERT_EQ
(
1
,
output_tensor_info
.
count
(
NAME_A
));
ASSERT_EQ
(
1
,
output_tensor_info
.
count
(
NAME_B
));
ASSERT_EQ
(
1
,
output_tensor_info
.
count
(
NAME_A_PLUS_B
));
EXPECT_EQ
(
DT_FLOAT
,
output_tensor_info
.
at
(
NAME_B
).
first
);
EXPECT_EQ
(
DT_FLOAT
,
output_tensor_info
.
at
(
NAME_A_PLUS_B
).
first
);
const
TensorShape
&
shape_b
=
output_tensor_info
.
at
(
NAME_B
).
second
;
const
TensorShape
&
shape_a_b
=
output_tensor_info
.
at
(
NAME_A_PLUS_B
).
second
;
ASSERT_EQ
(
3
,
tensor_shape_map
.
size
());
ASSERT_EQ
(
1
,
tensor_shape_map
.
count
(
NAME_A
));
ASSERT_EQ
(
1
,
tensor_shape_map
.
count
(
NAME_B
));
ASSERT_EQ
(
1
,
tensor_shape_map
.
count
(
NAME_A_PLUS_B
));
EXPECT_EQ
(
DT_FLOAT
,
tensor_shape_map
.
at
(
NAME_B
).
first
);
EXPECT_EQ
(
DT_FLOAT
,
tensor_shape_map
.
at
(
NAME_A_PLUS_B
).
first
);
const
TensorShape
&
shape_b
=
tensor_shape_map
.
at
(
NAME_B
).
second
;
const
TensorShape
&
shape_a_b
=
tensor_shape_map
.
at
(
NAME_A_PLUS_B
).
second
;
EXPECT_EQ
(
0
,
shape_b
.
dims
());
EXPECT_EQ
(
0
,
shape_a_b
.
dims
());
}
// Imports the A + B test graph, propagates the supplied scalar input shapes
// with PropagateShapeInference(), then checks that
// BuildTensorShapeMapFromGraph() records one entry per node with the expected
// data type and rank.
TEST(RemoteFusedGraphExecuteUtils, PropagateAndBuildTensorShapeMap) {
  // Both inputs are scalar float tensors.
  std::pair<string, Tensor> input_node_info_a;
  input_node_info_a.first = NAME_A;
  input_node_info_a.second = Tensor(DT_FLOAT, {});
  input_node_info_a.second.scalar<float>()() = NODE_A_VAL;
  std::pair<string, Tensor> input_node_info_b;
  input_node_info_b.first = NAME_B;
  input_node_info_b.second = Tensor(DT_FLOAT, {});
  input_node_info_b.second.scalar<float>()() = NODE_B_VAL;
  const std::vector<std::pair<string, Tensor>> inputs{input_node_info_a,
                                                      input_node_info_b};

  RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map;
  GraphDef def = RemoteFusedGraphExecuteOpTestUtils::BuildAddGraph(
      NAME_A, NODE_A_VAL, NAME_B, NODE_B_VAL, NAME_A_PLUS_B);
  ImportGraphDefOptions opts;
  Graph graph(OpRegistry::Global());
  ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry());

  // Fail right here if the import itself is broken, so later assertions are
  // not run against a half-built graph.
  Status status = ImportGraphDef(opts, def, &graph, &shape_refiner);
  ASSERT_TRUE(status.ok()) << status;

  status = RemoteFusedGraphExecuteUtils::PropagateShapeInference(
      def, inputs, &graph, &shape_refiner);
  ASSERT_TRUE(status.ok()) << status;

  status = RemoteFusedGraphExecuteUtils::BuildTensorShapeMapFromGraph(
      graph, shape_refiner, &tensor_shape_map);
  ASSERT_TRUE(status.ok()) << status;

  // One entry for each of A, B and A + B.
  ASSERT_EQ(3, tensor_shape_map.size());
  ASSERT_EQ(1, tensor_shape_map.count(NAME_A));
  ASSERT_EQ(1, tensor_shape_map.count(NAME_B));
  ASSERT_EQ(1, tensor_shape_map.count(NAME_A_PLUS_B));
  EXPECT_EQ(DT_FLOAT, tensor_shape_map.at(NAME_B).first);
  EXPECT_EQ(DT_FLOAT, tensor_shape_map.at(NAME_A_PLUS_B).first);
  const TensorShape& shape_b = tensor_shape_map.at(NAME_B).second;
  const TensorShape& shape_a_b = tensor_shape_map.at(NAME_A_PLUS_B).second;
  // Scalars have rank 0.
  EXPECT_EQ(0, shape_b.dims());
  EXPECT_EQ(0, shape_a_b.dims());
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录