Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_51669992
tensorflow
提交
06fb6333
T
tensorflow
项目概览
weixin_51669992
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
16
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
06fb6333
编写于
5月 03, 2019
作者:
J
Jiri Simsa
提交者:
TensorFlower Gardener
5月 03, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Automated rollback of commit
432de130
PiperOrigin-RevId: 246502904
上级
a85bbeb7
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
319 addition
and
581 deletion
+319
-581
tensorflow/core/framework/dataset.h
tensorflow/core/framework/dataset.h
+1
-0
tensorflow/core/grappler/optimizers/data/auto_shard.cc
tensorflow/core/grappler/optimizers/data/auto_shard.cc
+6
-7
tensorflow/core/grappler/optimizers/data/graph_utils.cc
tensorflow/core/grappler/optimizers/data/graph_utils.cc
+1
-2
tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
...orflow/core/grappler/optimizers/data/latency_all_edges.cc
+0
-8
tensorflow/core/grappler/optimizers/data/rebatch.cc
tensorflow/core/grappler/optimizers/data/rebatch.cc
+21
-4
tensorflow/core/kernels/data/BUILD
tensorflow/core/kernels/data/BUILD
+9
-21
tensorflow/core/kernels/data/dataset_ops.cc
tensorflow/core/kernels/data/dataset_ops.cc
+2
-1
tensorflow/core/kernels/data/dataset_utils.cc
tensorflow/core/kernels/data/dataset_utils.cc
+160
-1
tensorflow/core/kernels/data/dataset_utils.h
tensorflow/core/kernels/data/dataset_utils.h
+7
-0
tensorflow/core/kernels/data/experimental/BUILD
tensorflow/core/kernels/data/experimental/BUILD
+16
-20
tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.cc
...w/core/kernels/data/experimental/auto_shard_dataset_op.cc
+34
-68
tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc
...flow/core/kernels/data/experimental/rebatch_dataset_op.cc
+26
-55
tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
...low/core/kernels/data/experimental/snapshot_dataset_op.cc
+2
-1
tensorflow/core/kernels/data/graph_rewrite_dataset.cc
tensorflow/core/kernels/data/graph_rewrite_dataset.cc
+0
-240
tensorflow/core/kernels/data/graph_rewrite_dataset.h
tensorflow/core/kernels/data/graph_rewrite_dataset.h
+0
-95
tensorflow/core/kernels/data/optimize_dataset_op.cc
tensorflow/core/kernels/data/optimize_dataset_op.cc
+34
-58
未找到文件。
tensorflow/core/framework/dataset.h
浏览文件 @
06fb6333
...
...
@@ -683,6 +683,7 @@ class DatasetBase : public core::RefCounted {
protected:
friend
Status
AsGraphDef
(
OpKernelContext
*
ctx
,
const
DatasetBase
*
dataset
,
SerializationContext
&&
serialization_ctx
,
GraphDef
*
graph_def
);
// For access to graph related members.
friend
class
CapturedFunction
;
...
...
tensorflow/core/grappler/optimizers/data/auto_shard.cc
浏览文件 @
06fb6333
...
...
@@ -50,7 +50,8 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
"ZipDataset"
};
constexpr
std
::
array
<
const
char
*
,
22
>
kPassThroughOps
=
{
constexpr
std
::
array
<
const
char
*
,
23
>
kPassThroughOps
=
{
"_Retval"
,
"BatchDataset"
,
"BatchDatasetV2"
,
"ExperimentalMapAndBatchDataset"
,
...
...
@@ -285,16 +286,14 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
// function in flat_map.
if
(
IsDatasetNodeOfType
(
node
,
kFuncDatasetOps
)
&&
ReaderOpInFunction
(
node
,
*
flib
))
{
TF_RETURN_IF_ERROR
(
ProcessDatasetSourceNode
(
graph
,
node
,
nodes_to_delete
,
num_workers
,
index
));
return
Status
::
OK
();
return
ProcessDatasetSourceNode
(
graph
,
node
,
nodes_to_delete
,
num_workers
,
index
);
}
if
(
IsDatasetNodeOfType
(
node
,
kReaderDatasetOps
))
{
// We reached a reader dataset directly and we try to shard input 0.
TF_RETURN_IF_ERROR
(
ProcessDatasetSourceNode
(
graph
,
node
,
nodes_to_delete
,
num_workers
,
index
));
return
Status
::
OK
();
return
ProcessDatasetSourceNode
(
graph
,
node
,
nodes_to_delete
,
num_workers
,
index
);
}
if
(
!
IsDatasetNodeOfType
(
node
,
kPassThroughOps
))
{
...
...
tensorflow/core/grappler/optimizers/data/graph_utils.cc
浏览文件 @
06fb6333
...
...
@@ -301,12 +301,11 @@ Status EnsureNodeNamesUnique(Graph* g) {
return
Status
::
OK
();
}
// Tries to find a
Sink
node in the graph. A sink node is defined as a node
// Tries to find a
"sink"
node in the graph. A sink node is defined as a node
// that has at least one input and no outputs. If there are multiple of these,
// this might return any one of them. This is useful to identify the final
// Dataset op in the graph but in some cases there might be multiple Identity
// ops added to the end and this would return the last Identity op in that case.
Status
FindSinkNode
(
const
GraphDef
&
graph_def
,
NodeDef
*
sink_node
)
{
absl
::
flat_hash_map
<
string
,
int
>
all_node_names
;
absl
::
flat_hash_map
<
string
,
int
>
node_input_map
;
...
...
tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
浏览文件 @
06fb6333
...
...
@@ -83,14 +83,6 @@ Status LatencyAllEdges::OptimizeAndCollectStats(Cluster* cluster,
// node corresponds to a `Dataset` op.
continue
;
}
MutableGraphView
::
OutputPort
output_port
=
graph
.
GetOutputPort
(
node
.
name
(),
0
);
auto
fanout
=
graph
.
GetFanout
(
output_port
);
if
(
fanout
.
size
()
>
1
)
{
LOG
(
WARNING
)
<<
node
.
name
()
<<
" has fanout size "
<<
fanout
.
size
();
continue
;
}
// fanout will have size 0 for last dataset node in the pipeline.
NodeDef
*
latency_node
=
graph
.
AddNode
(
MakeLatencyNode
(
node
,
&
graph
));
TF_RETURN_IF_ERROR
(
graph
.
UpdateFanouts
(
node
.
name
(),
latency_node
->
name
()));
stats
->
num_changes
++
;
...
...
tensorflow/core/grappler/optimizers/data/rebatch.cc
浏览文件 @
06fb6333
...
...
@@ -41,8 +41,9 @@ Status RebatchOptimizer::Init(
namespace
{
constexpr
char
kCastOp
[]
=
"Cast"
;
constexpr
char
kRealDivOp
[]
=
"RealDiv"
;
constexpr
char
kConstOp
[]
=
"Const"
;
constexpr
char
kIdentityOp
[]
=
"Identity"
;
constexpr
char
kRealDivOp
[]
=
"RealDiv"
;
constexpr
std
::
array
<
const
char
*
,
5
>
kBatchDatasetOps
=
{
"BatchDataset"
,
...
...
@@ -135,12 +136,24 @@ bool IsDatasetNodeOfType(const NodeDef& node,
return
false
;
}
Status
UpdateOutputShapes
(
const
string
&
node_name
,
int64
num_workers
,
MutableGraphView
*
graph
)
{
NodeDef
*
node
=
graph
->
GetNode
(
node_name
);
if
(
node
->
op
()
==
kIdentityOp
)
{
return
Status
::
OK
();
}
AttrValue
output_shapes
=
node
->
attr
().
at
(
"output_shapes"
);
for
(
auto
&
shape
:
*
output_shapes
.
mutable_list
()
->
mutable_shape
())
{
shape
.
mutable_dim
(
0
)
->
set_size
(
shape
.
dim
(
0
).
size
()
/
num_workers
);
}
(
*
node
->
mutable_attr
())[
"output_shapes"
]
=
output_shapes
;
return
Status
::
OK
();
}
// Given a "batch" dataset node, modifies the batch_size input to divide the
// current batch size by num_workers.
Status
MutateBatchSize
(
const
NodeDef
&
node
,
int64
num_workers
,
MutableGraphView
*
graph
)
{
// TODO(rohanj): Fix up the output_shapes attribute as well. For this Dataset
// as well as all the downstream datasets.
// For all the batching datasets the batch_size is input number 1 except for
// MapAndBatchDataset.
int64
batch_size_arg_index
=
1
;
...
...
@@ -194,7 +207,8 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers,
FunctionLibraryDefinition
*
flib
,
MutableGraphView
*
graph
)
{
if
(
IsDatasetNodeOfType
(
node
,
kBatchDatasetOps
))
{
return
MutateBatchSize
(
node
,
num_workers
,
graph
);
TF_RETURN_IF_ERROR
(
MutateBatchSize
(
node
,
num_workers
,
graph
));
TF_RETURN_IF_ERROR
(
UpdateOutputShapes
(
node
.
name
(),
num_workers
,
graph
));
}
else
if
(
IsDatasetNodeOfType
(
node
,
kMultipleInputsDatasetOps
))
{
// For all multiple input datasets, all inputs are datasets themselves.
for
(
int
i
=
0
;
i
<
node
.
input_size
();
++
i
)
{
...
...
@@ -202,12 +216,14 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers,
TF_RETURN_IF_ERROR
(
RecursivelyHandleOp
(
*
input_node
,
num_workers
,
flib
,
graph
));
}
TF_RETURN_IF_ERROR
(
UpdateOutputShapes
(
node
.
name
(),
num_workers
,
graph
));
}
else
if
(
IsDatasetNodeOfType
(
node
,
kPassThroughOps
))
{
// For all the dataset ops that are pass through, the input dataset is
// input 0.
NodeDef
*
input_node
=
graph_utils
::
GetInputNode
(
node
,
*
graph
,
0
);
TF_RETURN_IF_ERROR
(
RecursivelyHandleOp
(
*
input_node
,
num_workers
,
flib
,
graph
));
TF_RETURN_IF_ERROR
(
UpdateOutputShapes
(
node
.
name
(),
num_workers
,
graph
));
}
else
if
(
IsDatasetNodeOfType
(
node
,
kFuncDatasetOps
))
{
const
string
func_name
=
node
.
attr
().
at
(
"f"
).
func
().
name
();
const
FunctionDef
*
fdef
=
flib
->
Find
(
func_name
);
...
...
@@ -233,6 +249,7 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers,
// Replace optimized function with a new FunctionDef.
TF_RETURN_IF_ERROR
(
flib
->
ReplaceFunction
(
func_name
,
optimized_func
));
TF_RETURN_IF_ERROR
(
UpdateOutputShapes
(
node
.
name
(),
num_workers
,
graph
));
}
else
{
VLOG
(
2
)
<<
"Failed to optimize dataset function. Error: "
<<
s
.
error_message
();
...
...
tensorflow/core/kernels/data/BUILD
浏览文件 @
06fb6333
...
...
@@ -51,6 +51,14 @@ cc_library(
"//tensorflow/core:framework"
,
"//tensorflow/core:lib"
,
"//tensorflow/core:lib_internal"
,
"//tensorflow/core:protos_all_cc"
,
"//tensorflow/core/grappler:grappler_item"
,
"//tensorflow/core/grappler:grappler_item_builder"
,
"//tensorflow/core/grappler/clusters:virtual_cluster"
,
"//tensorflow/core/grappler/optimizers:meta_optimizer"
,
"//tensorflow/core/grappler/optimizers/data"
,
"//tensorflow/core/grappler/optimizers/data:function_utils"
,
"//tensorflow/core/grappler/optimizers/data:graph_utils"
,
],
)
...
...
@@ -936,31 +944,11 @@ tf_kernel_library(
],
)
cc_library
(
name
=
"graph_rewrite_dataset"
,
srcs
=
[
"graph_rewrite_dataset.cc"
],
hdrs
=
[
"graph_rewrite_dataset.h"
],
deps
=
[
":captured_function"
,
":dataset_utils"
,
"//tensorflow/core:core_cpu_internal"
,
"//tensorflow/core:framework"
,
"//tensorflow/core:protos_all_cc"
,
"//tensorflow/core/grappler:grappler_item"
,
"//tensorflow/core/grappler:grappler_item_builder"
,
"//tensorflow/core/grappler/clusters:virtual_cluster"
,
"//tensorflow/core/grappler/optimizers:meta_optimizer"
,
"//tensorflow/core/grappler/optimizers/data"
,
"//tensorflow/core/grappler/optimizers/data:function_utils"
,
"//tensorflow/core/grappler/optimizers/data:graph_utils"
,
],
)
tf_kernel_library
(
name
=
"optimize_dataset_op"
,
srcs
=
[
"optimize_dataset_op.cc"
],
deps
=
[
":
graph_rewrite_dataset
"
,
":
dataset_utils
"
,
"//tensorflow/core:core_cpu_internal"
,
"//tensorflow/core:dataset_ops_op_lib"
,
"//tensorflow/core:framework"
,
...
...
tensorflow/core/kernels/data/dataset_ops.cc
浏览文件 @
06fb6333
...
...
@@ -32,7 +32,8 @@ class DatasetToGraphOp : public OpKernel {
DatasetBase
*
dataset
;
OP_REQUIRES_OK
(
ctx
,
GetDatasetFromVariantTensor
(
ctx
->
input
(
0
),
&
dataset
));
GraphDef
graph_def
;
OP_REQUIRES_OK
(
ctx
,
AsGraphDef
(
ctx
,
dataset
,
&
graph_def
));
OP_REQUIRES_OK
(
ctx
,
AsGraphDef
(
ctx
,
dataset
,
SerializationContext
({}),
&
graph_def
));
Tensor
*
result
;
OP_REQUIRES_OK
(
ctx
,
ctx
->
allocate_output
(
0
,
TensorShape
({}),
&
result
));
result
->
scalar
<
string
>
()()
=
graph_def
.
SerializeAsString
();
...
...
tensorflow/core/kernels/data/dataset_utils.cc
浏览文件 @
06fb6333
...
...
@@ -17,20 +17,128 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/util/work_sharder.h"
namespace
tensorflow
{
namespace
data
{
namespace
{
void
AddFakeSinks
(
FunctionDef
*
function_def
)
{
int
counter
=
0
;
for
(
const
auto
&
output
:
function_def
->
signature
().
output_arg
())
{
NodeDef
*
node
=
function_def
->
add_node_def
();
tensorflow
::
grappler
::
function_utils
::
SetUniqueFunctionNodeName
(
strings
::
StrCat
(
"FakeSink"
,
counter
++
),
function_def
,
node
);
node
->
set_op
(
"Identity"
);
node
->
add_input
(
function_def
->
ret
().
at
(
output
.
name
()));
(
*
node
->
mutable_attr
())[
"T"
].
set_type
(
output
.
type
());
(
*
function_def
->
mutable_ret
())[
output
.
name
()]
=
strings
::
StrCat
(
node
->
name
(),
":output:0"
);
}
}
void
RemoveFakeSinks
(
FunctionDef
*
function_def
)
{
// Map from identity node names to their input tensor strings
std
::
map
<
string
,
string
>
identity_map
;
for
(
const
auto
&
node
:
function_def
->
node_def
())
{
if
(
node
.
op
()
==
"Identity"
&&
node
.
input_size
()
==
1
)
{
identity_map
[
node
.
name
()]
=
node
.
input
(
0
);
}
}
for
(
const
auto
&
output_arg
:
function_def
->
signature
().
output_arg
())
{
const
string
&
tensor
=
function_def
->
ret
().
at
(
output_arg
.
name
());
const
string
&
output_node
=
tensor
.
substr
(
0
,
tensor
.
find
(
':'
));
if
(
identity_map
.
find
(
output_node
)
!=
identity_map
.
end
())
{
(
*
function_def
->
mutable_ret
())[
output_arg
.
name
()]
=
identity_map
.
at
(
output_node
);
}
}
}
Status
ApplyRewrites
(
OpKernelContext
*
ctx
,
const
std
::
function
<
RewriterConfig
(
void
)
>
config_factory
,
bool
optimize_function_library
,
GraphDef
*
graph_def
,
string
*
output_node
)
{
// Add an identity node as the fetch node, otherwise we might get 'placeholder
// is both fed and fetched' errors in some cases when using input list with
// placeholder dataset nodes.
NodeDef
*
node
=
graph_def
->
mutable_node
()
->
Add
();
tensorflow
::
grappler
::
graph_utils
::
SetUniqueGraphNodeName
(
"Sink"
,
graph_def
,
node
);
node
->
set_op
(
"Identity"
);
node
->
add_input
(
*
output_node
);
(
*
node
->
mutable_attr
())[
"T"
].
set_type
(
DT_VARIANT
);
*
output_node
=
node
->
name
();
// Add fake sink node to graph and functions to allow rewriting the actual
// sink nodes.
//
// TODO(b/118820916): When MetaOptimizer adds provisions for function retvals
// to be optimizable, we will no longer need this.
for
(
auto
&
function_def
:
*
graph_def
->
mutable_library
()
->
mutable_function
())
{
AddFakeSinks
(
&
function_def
);
}
// Create metagraph.
MetaGraphDef
meta_graph_def
;
(
*
meta_graph_def
.
mutable_graph_def
())
=
*
graph_def
;
// Grappler determines fetch ops from collection 'train_op'.
CollectionDef
collection_def
;
auto
node_list
=
collection_def
.
mutable_node_list
();
node_list
->
add_value
(
*
output_node
);
(
*
meta_graph_def
.
mutable_collection_def
())[
"train_op"
]
=
collection_def
;
// Create Grappler item.
tensorflow
::
grappler
::
ItemConfig
item_config
;
item_config
.
apply_optimizations
=
true
;
std
::
unique_ptr
<
tensorflow
::
grappler
::
GrapplerItem
>
grappler_item
=
tensorflow
::
grappler
::
GrapplerItemFromMetaGraphDef
(
"graph"
,
meta_graph_def
,
item_config
);
grappler_item
->
optimization_options
().
optimize_function_library
=
optimize_function_library
;
std
::
unordered_map
<
string
,
tensorflow
::
DeviceProperties
>
device_map
;
tensorflow
::
grappler
::
VirtualCluster
cluster
(
device_map
);
// Run data optimizer using grappler's meta optimizer.
tensorflow
::
ConfigProto
config
;
*
config
.
mutable_graph_options
()
->
mutable_rewrite_options
()
=
config_factory
();
TF_RETURN_IF_ERROR
(
tensorflow
::
grappler
::
RunMetaOptimizer
(
*
grappler_item
,
config
,
ctx
->
device
(),
&
cluster
,
graph_def
));
// Remove fake sinks after optimizations are done.
//
// TODO(b/118820916): When MetaOptimizer adds provisions for function retvals
// to be optimizable, we will no longer need this.
for
(
auto
&
function_def
:
*
graph_def
->
mutable_library
()
->
mutable_function
())
{
RemoveFakeSinks
(
&
function_def
);
}
return
Status
::
OK
();
}
}
// anonymous namespace
Status
AsGraphDef
(
OpKernelContext
*
ctx
,
const
DatasetBase
*
dataset
,
SerializationContext
&&
serialization_ctx
,
GraphDef
*
graph_def
)
{
GraphDefBuilder
b
;
DatasetBase
::
DatasetGraphDefBuilder
db
(
&
b
);
Node
*
output_node
=
nullptr
;
SerializationContext
serialization_ctx
({});
TF_RETURN_IF_ERROR
(
db
.
AddInputDataset
(
&
serialization_ctx
,
dataset
,
&
output_node
));
// Insert a purely symbolic _Retval node to indicate to consumers which Tensor
...
...
@@ -44,6 +152,57 @@ Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
return
Status
::
OK
();
}
Status
RewriteDataset
(
OpKernelContext
*
ctx
,
const
DatasetBase
*
input
,
std
::
function
<
RewriterConfig
(
void
)
>
config_factory
,
bool
optimize_function_library
,
DatasetBase
**
rewritten_input
)
{
SerializationContext
::
Params
params
;
std
::
vector
<
std
::
pair
<
string
,
Tensor
>>
input_list
;
params
.
input_list
=
&
input_list
;
params
.
optimization_only
=
true
;
SerializationContext
serialization_ctx
(
params
);
GraphDef
graph_def
;
TF_RETURN_IF_ERROR
(
AsGraphDef
(
ctx
,
input
,
std
::
move
(
serialization_ctx
),
&
graph_def
));
string
output_node
;
for
(
const
auto
&
node
:
graph_def
.
node
())
{
if
(
node
.
op
()
==
"_Retval"
)
{
output_node
=
node
.
input
(
0
);
}
}
VLOG
(
3
)
<<
"Before graph rewrites: "
<<
graph_def
.
DebugString
();
TF_RETURN_IF_ERROR
(
ApplyRewrites
(
ctx
,
config_factory
,
optimize_function_library
,
&
graph_def
,
&
output_node
));
VLOG
(
3
)
<<
"After graph rewrites: "
<<
graph_def
.
DebugString
();
// Instantiate the optimized input pipeline by running the optimized graph
// using the optimized function library.
FunctionLibraryRuntime
*
flr
=
nullptr
;
std
::
unique_ptr
<
ProcessFunctionLibraryRuntime
>
pflr
=
nullptr
;
std
::
unique_ptr
<
FunctionLibraryDefinition
>
lib_def
=
nullptr
;
TF_RETURN_IF_ERROR
(
ctx
->
function_library
()
->
Clone
(
&
lib_def
,
&
pflr
,
&
flr
,
true
));
// Some functions may have been modified without having their names
// changed (for example, nested dataset graphs from FlatMap or
// Interleave).
TF_RETURN_IF_ERROR
(
AddToFunctionLibrary
(
lib_def
.
get
(),
graph_def
.
library
()));
Graph
graph
(
OpRegistry
::
Global
());
TF_RETURN_IF_ERROR
(
ImportGraphDef
({},
graph_def
,
&
graph
,
nullptr
));
std
::
vector
<
Tensor
>
outputs
;
GraphRunner
graph_runner
(
flr
->
device
());
TF_RETURN_IF_ERROR
(
graph_runner
.
Run
(
&
graph
,
flr
,
input_list
,
{
output_node
},
&
outputs
));
TF_RETURN_IF_ERROR
(
GetDatasetFromVariantTensor
(
outputs
[
0
],
rewritten_input
));
(
*
rewritten_input
)
->
Ref
();
return
Status
::
OK
();
}
Status
VerifyTypesMatch
(
const
DataTypeVector
&
expected
,
const
DataTypeVector
&
received
)
{
if
(
expected
.
size
()
!=
received
.
size
())
{
...
...
tensorflow/core/kernels/data/dataset_utils.h
浏览文件 @
06fb6333
...
...
@@ -23,8 +23,15 @@ namespace data {
// Returns a GraphDef representation of the given dataset.
Status
AsGraphDef
(
OpKernelContext
*
ctx
,
const
DatasetBase
*
dataset
,
SerializationContext
&&
serialization_ctx
,
GraphDef
*
graph_def
);
// Rewrites the input dataset using the given config.
Status
RewriteDataset
(
OpKernelContext
*
ctx
,
const
DatasetBase
*
input
,
std
::
function
<
RewriterConfig
(
void
)
>
config_factory
,
bool
optimize_function_library
,
DatasetBase
**
rewritten_input
);
// Returns Status::OK() if `expected` and `received` types match,
// errors::InvalidArgument otherwise.
Status
VerifyTypesMatch
(
const
DataTypeVector
&
expected
,
...
...
tensorflow/core/kernels/data/experimental/BUILD
浏览文件 @
06fb6333
...
...
@@ -10,10 +10,6 @@ load(
"//tensorflow:tensorflow.bzl"
,
"tf_kernel_library"
,
)
load
(
"//tensorflow/core:platform/default/build_config.bzl"
,
"tf_proto_library"
,
)
tf_kernel_library
(
name
=
"assert_next_dataset_op"
,
...
...
@@ -25,6 +21,21 @@ tf_kernel_library(
],
)
tf_kernel_library
(
name
=
"auto_shard_dataset_op"
,
srcs
=
[
"auto_shard_dataset_op.cc"
],
deps
=
[
"//tensorflow/core:core_cpu_internal"
,
"//tensorflow/core:dataset_ops_op_lib"
,
"//tensorflow/core:framework"
,
"//tensorflow/core:lib"
,
"//tensorflow/core:lib_internal"
,
"//tensorflow/core:protos_all_cc"
,
"//tensorflow/core/grappler/optimizers/data:auto_shard"
,
"//tensorflow/core/kernels/data:dataset_utils"
,
],
)
tf_kernel_library
(
name
=
"choose_fastest_branch_dataset_op"
,
srcs
=
[
"choose_fastest_branch_dataset_op.cc"
],
...
...
@@ -73,21 +84,6 @@ tf_kernel_library(
],
)
tf_kernel_library
(
name
=
"auto_shard_dataset_op"
,
srcs
=
[
"auto_shard_dataset_op.cc"
],
deps
=
[
"//tensorflow/core:core_cpu_internal"
,
"//tensorflow/core:dataset_ops_op_lib"
,
"//tensorflow/core:framework"
,
"//tensorflow/core:lib"
,
"//tensorflow/core:lib_internal"
,
"//tensorflow/core:protos_all_cc"
,
"//tensorflow/core/grappler/optimizers/data:auto_shard"
,
"//tensorflow/core/kernels/data:graph_rewrite_dataset"
,
],
)
tf_kernel_library
(
name
=
"group_by_reducer_dataset_op"
,
srcs
=
[
"group_by_reducer_dataset_op.cc"
],
...
...
@@ -251,7 +247,7 @@ tf_kernel_library(
"//tensorflow/core:lib_internal"
,
"//tensorflow/core:protos_all_cc"
,
"//tensorflow/core/grappler/optimizers/data:rebatch"
,
"//tensorflow/core/kernels/data:
graph_rewrite_dataset
"
,
"//tensorflow/core/kernels/data:
dataset_utils
"
,
],
)
...
...
tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.cc
浏览文件 @
06fb6333
...
...
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace
tensorflow
{
namespace
data
{
...
...
@@ -24,17 +25,12 @@ constexpr char kOptimizerName[] = "tf_auto_shard";
class
AutoShardDatasetOp
:
public
UnaryDatasetOpKernel
{
public:
explicit
AutoShardDatasetOp
(
OpKernelConstruction
*
ctx
)
:
UnaryDatasetOpKernel
(
ctx
),
graph_def_version_
(
ctx
->
graph_def_version
())
{
OP_REQUIRES_OK
(
ctx
,
ctx
->
GetAttr
(
"output_types"
,
&
output_types_
));
OP_REQUIRES_OK
(
ctx
,
ctx
->
GetAttr
(
"output_shapes"
,
&
output_shapes_
));
}
:
UnaryDatasetOpKernel
(
ctx
)
{}
protected:
void
MakeDataset
(
OpKernelContext
*
ctx
,
DatasetBase
*
input
,
DatasetBase
**
output
)
override
{
int64
index
;
int64
num_workers
;
int64
index
,
num_workers
;
OP_REQUIRES_OK
(
ctx
,
ParseScalarArgument
(
ctx
,
"num_workers"
,
&
num_workers
));
OP_REQUIRES
(
ctx
,
num_workers
>
0
,
...
...
@@ -45,69 +41,39 @@ class AutoShardDatasetOp : public UnaryDatasetOpKernel {
errors
::
InvalidArgument
(
"index must be between 0 and "
,
num_workers
-
1
));
Dataset
*
dataset
=
new
Dataset
(
ctx
,
input
,
num_workers
,
index
,
output_types_
,
output_shapes_
);
const
Status
s
=
dataset
->
Optimize
(
ctx
);
if
(
s
.
ok
())
{
*
output
=
dataset
;
}
else
{
dataset
->
Unref
();
OP_REQUIRES_OK
(
ctx
,
s
);
}
auto
config_factory
=
[
num_workers
,
index
]()
{
return
CreateConfig
(
num_workers
,
index
);
};
// We only want to optimize functions for some particular datasets like
// FlatMapDataset, InterleaveDataset etc. So we disable generalized
// function optimization and explicitly handle function modifications
// for those datasets in the rewrite.
OP_REQUIRES_OK
(
ctx
,
RewriteDataset
(
ctx
,
input
,
std
::
move
(
config_factory
),
/*optimize_function_library=*/
false
,
output
));
}
private:
class
Dataset
:
public
GraphRewriteDataset
{
public:
Dataset
(
OpKernelContext
*
ctx
,
const
DatasetBase
*
input
,
const
int64
num_workers
,
const
int64
index
,
const
DataTypeVector
&
output_types
,
const
std
::
vector
<
PartialTensorShape
>&
output_shapes
)
:
GraphRewriteDataset
(
ctx
,
input
,
output_types
,
output_shapes
),
num_workers_
(
num_workers
),
index_
(
index
)
{}
string
DebugString
()
const
override
{
return
"AutoShardDatasetOp::Dataset"
;
}
private:
bool
ShouldOptimizeFunctions
()
override
{
// We only want to optimize functions for some particular datasets like
// FlatMapDataset, InterleaveDataset etc. So we disable generalized
// function optimization and explicitly handle function modifications
// for those datasets in the rewrite.
return
false
;
}
RewriterConfig
CreateGrapplerRewriteConfig
()
override
{
RewriterConfig
rewriter_config
;
rewriter_config
.
set_fail_on_optimizer_errors
(
true
);
rewriter_config
.
add_optimizers
(
kOptimizerName
);
rewriter_config
.
set_meta_optimizer_iterations
(
RewriterConfig_NumIterationsType_ONE
);
auto
custom_optimizer
=
rewriter_config
.
add_custom_optimizers
();
custom_optimizer
->
set_name
(
kOptimizerName
);
AttrValue
num_workers_attr
;
num_workers_attr
.
set_i
(
num_workers_
);
(
*
custom_optimizer
->
mutable_parameter_map
())[
"num_workers"
]
=
num_workers_attr
;
AttrValue
index_attr
;
index_attr
.
set_i
(
index_
);
(
*
custom_optimizer
->
mutable_parameter_map
())[
"index"
]
=
index_attr
;
return
rewriter_config
;
}
const
int64
num_workers_
;
const
int64
index_
;
};
const
int
graph_def_version_
;
DataTypeVector
output_types_
;
std
::
vector
<
PartialTensorShape
>
output_shapes_
;
static
RewriterConfig
CreateConfig
(
int64
num_workers
,
int64
index
)
{
RewriterConfig
rewriter_config
;
rewriter_config
.
set_fail_on_optimizer_errors
(
true
);
rewriter_config
.
add_optimizers
(
kOptimizerName
);
rewriter_config
.
set_meta_optimizer_iterations
(
RewriterConfig_NumIterationsType_ONE
);
auto
custom_optimizer
=
rewriter_config
.
add_custom_optimizers
();
custom_optimizer
->
set_name
(
kOptimizerName
);
AttrValue
num_workers_attr
;
num_workers_attr
.
set_i
(
num_workers
);
(
*
custom_optimizer
->
mutable_parameter_map
())[
"num_workers"
]
=
num_workers_attr
;
AttrValue
index_attr
;
index_attr
.
set_i
(
index
);
(
*
custom_optimizer
->
mutable_parameter_map
())[
"index"
]
=
index_attr
;
return
rewriter_config
;
}
};
REGISTER_KERNEL_BUILDER
(
Name
(
"ExperimentalAutoShardDataset"
).
Device
(
DEVICE_CPU
),
...
...
tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc
浏览文件 @
06fb6333
...
...
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace
tensorflow
{
namespace
data
{
...
...
@@ -24,11 +25,7 @@ constexpr char kOptimizerName[] = "tf_data_rebatcher";
class
RebatchDatasetOp
:
public
UnaryDatasetOpKernel
{
public:
explicit
RebatchDatasetOp
(
OpKernelConstruction
*
ctx
)
:
UnaryDatasetOpKernel
(
ctx
),
graph_def_version_
(
ctx
->
graph_def_version
())
{
OP_REQUIRES_OK
(
ctx
,
ctx
->
GetAttr
(
"output_types"
,
&
output_types_
));
OP_REQUIRES_OK
(
ctx
,
ctx
->
GetAttr
(
"output_shapes"
,
&
output_shapes_
));
}
:
UnaryDatasetOpKernel
(
ctx
)
{}
protected:
void
MakeDataset
(
OpKernelContext
*
ctx
,
DatasetBase
*
input
,
...
...
@@ -39,58 +36,32 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel {
ctx
,
num_workers
>
0
,
errors
::
InvalidArgument
(
"num_workers must be greater than zero."
));
Dataset
*
dataset
=
new
Dataset
(
ctx
,
input
,
num_workers
,
output_types_
,
output_shapes_
);
Status
s
=
dataset
->
Optimize
(
ctx
);
if
(
s
.
ok
())
{
*
output
=
dataset
;
}
else
{
dataset
->
Unref
();
OP_REQUIRES_OK
(
ctx
,
s
);
}
auto
config_factory
=
[
num_workers
]()
{
return
CreateConfig
(
num_workers
);
};
// We only want to optimize functions for some particular datasets like
// FlatMapDataset, InterleaveDataset etc. So we disable generalized
// function optimization and explicitly handle function modifications
// for those datasets in the rewrite.
OP_REQUIRES_OK
(
ctx
,
RewriteDataset
(
ctx
,
input
,
std
::
move
(
config_factory
),
/*optimize_function_library=*/
false
,
output
));
}
private:
class
Dataset
:
public
GraphRewriteDataset
{
public:
Dataset
(
OpKernelContext
*
ctx
,
const
DatasetBase
*
input
,
const
int64
num_workers
,
const
DataTypeVector
&
output_types
,
const
std
::
vector
<
PartialTensorShape
>&
output_shapes
)
:
GraphRewriteDataset
(
ctx
,
input
,
output_types
,
output_shapes
),
num_workers_
(
num_workers
)
{}
string
DebugString
()
const
override
{
return
"RebatchDatasetOp::Dataset"
;
}
private:
bool
ShouldOptimizeFunctions
()
override
{
// We only want to optimize functions for some particular datasets like
// FlatMapDataset, InterleaveDataset etc. So we disable generalized
// function optimization and explicitly handle function modifications
// for those datasets in the rewrite.
return
false
;
}
RewriterConfig
CreateGrapplerRewriteConfig
()
override
{
RewriterConfig
rewriter_config
;
rewriter_config
.
set_fail_on_optimizer_errors
(
true
);
rewriter_config
.
add_optimizers
(
kOptimizerName
);
rewriter_config
.
set_meta_optimizer_iterations
(
RewriterConfig_NumIterationsType_ONE
);
auto
custom_optimizer
=
rewriter_config
.
add_custom_optimizers
();
custom_optimizer
->
set_name
(
kOptimizerName
);
AttrValue
num_workers_attr
;
num_workers_attr
.
set_i
(
num_workers_
);
(
*
custom_optimizer
->
mutable_parameter_map
())[
"num_workers"
]
=
num_workers_attr
;
return
rewriter_config
;
}
const
int64
num_workers_
;
};
const
int
graph_def_version_
;
DataTypeVector
output_types_
;
std
::
vector
<
PartialTensorShape
>
output_shapes_
;
static
RewriterConfig
CreateConfig
(
int64
num_workers
)
{
RewriterConfig
rewriter_config
;
rewriter_config
.
set_fail_on_optimizer_errors
(
true
);
rewriter_config
.
add_optimizers
(
kOptimizerName
);
rewriter_config
.
set_meta_optimizer_iterations
(
RewriterConfig_NumIterationsType_ONE
);
auto
custom_optimizer
=
rewriter_config
.
add_custom_optimizers
();
custom_optimizer
->
set_name
(
kOptimizerName
);
AttrValue
num_workers_attr
;
num_workers_attr
.
set_i
(
num_workers
);
(
*
custom_optimizer
->
mutable_parameter_map
())[
"num_workers"
]
=
num_workers_attr
;
return
rewriter_config
;
}
};
REGISTER_KERNEL_BUILDER
(
Name
(
"ExperimentalRebatchDataset"
).
Device
(
DEVICE_CPU
),
...
...
tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
浏览文件 @
06fb6333
...
...
@@ -123,7 +123,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK
(
ctx
,
ParseScalarArgument
(
ctx
,
"path"
,
&
path
));
GraphDef
graph_def
;
OP_REQUIRES_OK
(
ctx
,
AsGraphDef
(
ctx
,
input
,
&
graph_def
));
OP_REQUIRES_OK
(
ctx
,
AsGraphDef
(
ctx
,
input
,
SerializationContext
({}),
&
graph_def
));
// TODO(frankchn): Find a better way than SerializeToStringDeterministic()
// This is not deterministic across different builds of binaries right now.
...
...
tensorflow/core/kernels/data/graph_rewrite_dataset.cc
已删除
100644 → 0
浏览文件 @
a85bbeb7
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h"
#include <memory>
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace
tensorflow
{
namespace
data
{
GraphRewriteDataset
::~
GraphRewriteDataset
()
{
input_
->
Unref
();
if
(
optimized_input_
)
{
optimized_input_
->
Unref
();
}
}
Status
GraphRewriteDataset
::
Optimize
(
OpKernelContext
*
ctx
)
{
GraphDefBuilder
b
;
DatasetGraphDefBuilder
db
(
&
b
);
Node
*
input_node
=
nullptr
;
SerializationContext
::
Params
params
;
std
::
vector
<
std
::
pair
<
string
,
Tensor
>>
input_list
;
params
.
input_list
=
&
input_list
;
params
.
optimization_only
=
true
;
SerializationContext
serialization_ctx
(
params
);
TF_RETURN_IF_ERROR
(
db
.
AddInputDataset
(
&
serialization_ctx
,
input_
,
&
input_node
));
string
output_node
=
input_node
->
name
();
GraphDef
graph_def
;
TF_RETURN_IF_ERROR
(
b
.
ToGraphDef
(
&
graph_def
));
VLOG
(
3
)
<<
"Before optimization: "
<<
graph_def
.
DebugString
();
TF_RETURN_IF_ERROR
(
ApplyOptimizations
(
ctx
,
&
graph_def
,
&
output_node
));
VLOG
(
3
)
<<
"After optimization: "
<<
graph_def
.
DebugString
();
// Instantiate the optimized input pipeline by running the optimized graph
// using the optimized function library.
TF_RETURN_IF_ERROR
(
ctx
->
function_library
()
->
Clone
(
&
lib_def_
,
&
pflr_
,
&
flr_
,
true
));
// Create a FunctionHandleCache.
function_handle_cache_
=
absl
::
make_unique
<
FunctionHandleCache
>
(
flr_
);
// Some functions may have been modified without having their names
// changed (for example, nested dataset graphs from FlatMap or
// Interleave).
TF_RETURN_IF_ERROR
(
AddToFunctionLibrary
(
lib_def_
.
get
(),
graph_def
.
library
()));
Graph
graph
(
OpRegistry
::
Global
());
TF_RETURN_IF_ERROR
(
ImportGraphDef
({},
graph_def
,
&
graph
,
nullptr
));
std
::
vector
<
Tensor
>
outputs
;
GraphRunner
graph_runner
(
flr_
->
device
());
TF_RETURN_IF_ERROR
(
graph_runner
.
Run
(
&
graph
,
flr_
,
input_list
,
{
output_node
},
&
outputs
));
TF_RETURN_IF_ERROR
(
GetDatasetFromVariantTensor
(
outputs
[
0
],
&
optimized_input_
));
optimized_input_
->
Ref
();
return
Status
::
OK
();
}
Status
GraphRewriteDataset
::
AsGraphDefInternal
(
SerializationContext
*
ctx
,
DatasetGraphDefBuilder
*
b
,
Node
**
output
)
const
{
// We only serialize the optimized dataset to avoid re-running optimizations
// when the input pipeline is restored from a checkpoint.
TF_RETURN_IF_ERROR
(
b
->
AddInputDataset
(
ctx
,
optimized_input_
,
output
));
return
Status
::
OK
();
}
namespace
{
void
AddFakeSinks
(
FunctionDef
*
function_def
)
{
int
counter
=
0
;
for
(
const
auto
&
output
:
function_def
->
signature
().
output_arg
())
{
NodeDef
*
node
=
function_def
->
add_node_def
();
tensorflow
::
grappler
::
function_utils
::
SetUniqueFunctionNodeName
(
strings
::
StrCat
(
"FakeSink"
,
counter
++
),
function_def
,
node
);
node
->
set_op
(
"Identity"
);
node
->
add_input
(
function_def
->
ret
().
at
(
output
.
name
()));
(
*
node
->
mutable_attr
())[
"T"
].
set_type
(
output
.
type
());
(
*
function_def
->
mutable_ret
())[
output
.
name
()]
=
strings
::
StrCat
(
node
->
name
(),
":output:0"
);
}
}
void
RemoveFakeSinks
(
FunctionDef
*
function_def
)
{
// Map from identity node names to their input tensor strings
std
::
map
<
string
,
string
>
identity_map
;
for
(
const
auto
&
node
:
function_def
->
node_def
())
{
if
(
node
.
op
()
==
"Identity"
&&
node
.
input_size
()
==
1
)
{
identity_map
[
node
.
name
()]
=
node
.
input
(
0
);
}
}
for
(
const
auto
&
output_arg
:
function_def
->
signature
().
output_arg
())
{
const
string
&
tensor
=
function_def
->
ret
().
at
(
output_arg
.
name
());
const
string
&
output_node
=
tensor
.
substr
(
0
,
tensor
.
find
(
':'
));
if
(
identity_map
.
find
(
output_node
)
!=
identity_map
.
end
())
{
(
*
function_def
->
mutable_ret
())[
output_arg
.
name
()]
=
identity_map
.
at
(
output_node
);
}
}
}
}
// anonymous namespace
Status
GraphRewriteDataset
::
ApplyOptimizations
(
OpKernelContext
*
ctx
,
GraphDef
*
graph_def
,
string
*
output_node
)
{
// Add an identity node as the fetch node, otherwise we might get 'placeholder
// is both fed and fetched' errors in some cases when using input list with
// placeholder dataset nodes.
NodeDef
*
node
=
graph_def
->
mutable_node
()
->
Add
();
tensorflow
::
grappler
::
graph_utils
::
SetUniqueGraphNodeName
(
"Sink"
,
graph_def
,
node
);
node
->
set_op
(
"Identity"
);
node
->
add_input
(
*
output_node
);
(
*
node
->
mutable_attr
())[
"T"
].
set_type
(
DT_VARIANT
);
*
output_node
=
node
->
name
();
// Add fake sink node to graph and functions to allow rewriting the actual
// sink nodes.
//
// TODO(b/118820916): When MetaOptimizer adds provisions for function retvals
// to be optimizable, we will no longer need this.
for
(
auto
&
function_def
:
*
graph_def
->
mutable_library
()
->
mutable_function
())
{
AddFakeSinks
(
&
function_def
);
}
// Create metagraph.
MetaGraphDef
meta_graph_def
;
(
*
meta_graph_def
.
mutable_graph_def
())
=
*
graph_def
;
// Grappler determines fetch ops from collection 'train_op'.
CollectionDef
collection_def
;
auto
node_list
=
collection_def
.
mutable_node_list
();
node_list
->
add_value
(
*
output_node
);
(
*
meta_graph_def
.
mutable_collection_def
())[
"train_op"
]
=
collection_def
;
// Create Grappler item.
tensorflow
::
grappler
::
ItemConfig
item_config
;
item_config
.
apply_optimizations
=
true
;
std
::
unique_ptr
<
tensorflow
::
grappler
::
GrapplerItem
>
grappler_item
=
tensorflow
::
grappler
::
GrapplerItemFromMetaGraphDef
(
"graph"
,
meta_graph_def
,
item_config
);
grappler_item
->
optimization_options
().
optimize_function_library
=
ShouldOptimizeFunctions
();
std
::
unordered_map
<
string
,
tensorflow
::
DeviceProperties
>
device_map
;
tensorflow
::
grappler
::
VirtualCluster
cluster
(
device_map
);
// Run data optimizer using grappler's meta optimizer.
tensorflow
::
ConfigProto
config
;
*
config
.
mutable_graph_options
()
->
mutable_rewrite_options
()
=
CreateGrapplerRewriteConfig
();
TF_RETURN_IF_ERROR
(
tensorflow
::
grappler
::
RunMetaOptimizer
(
*
grappler_item
,
config
,
ctx
->
device
(),
&
cluster
,
graph_def
));
// Remove fake sinks after optimizations are done.
//
// TODO(b/118820916): When MetaOptimizer adds provisions for function retvals
// to be optimizable, we will no longer need this.
for
(
auto
&
function_def
:
*
graph_def
->
mutable_library
()
->
mutable_function
())
{
RemoveFakeSinks
(
&
function_def
);
}
return
Status
::
OK
();
}
class
GraphRewriteDataset
::
Iterator
:
public
DatasetIterator
<
GraphRewriteDataset
>
{
public:
explicit
Iterator
(
const
Params
&
params
)
:
DatasetIterator
<
GraphRewriteDataset
>
(
params
)
{}
Status
Initialize
(
IteratorContext
*
ctx
)
override
{
IteratorContext
::
Params
params
(
ctx
);
params
.
flr
=
dataset
()
->
flr_
;
params
.
function_handle_cache
=
dataset
()
->
function_handle_cache_
.
get
();
return
dataset
()
->
optimized_input_
->
MakeIterator
(
IteratorContext
(
std
::
move
(
params
)),
prefix
(),
&
input_impl_
);
}
Status
GetNextInternal
(
IteratorContext
*
ctx
,
std
::
vector
<
Tensor
>*
out_tensors
,
bool
*
end_of_sequence
)
override
{
IteratorContext
::
Params
params
(
ctx
);
params
.
flr
=
dataset
()
->
flr_
;
params
.
function_handle_cache
=
dataset
()
->
function_handle_cache_
.
get
();
return
input_impl_
->
GetNext
(
IteratorContext
(
std
::
move
(
params
)),
out_tensors
,
end_of_sequence
);
}
protected:
std
::
shared_ptr
<
model
::
Node
>
CreateNode
(
IteratorContext
*
ctx
,
model
::
Node
::
Args
args
)
const
override
{
return
model
::
MakeKnownRatioNode
(
std
::
move
(
args
),
/*ratio=*/
1
);
}
Status
SaveInternal
(
IteratorStateWriter
*
writer
)
override
{
TF_RETURN_IF_ERROR
(
SaveInput
(
writer
,
input_impl_
));
return
Status
::
OK
();
}
Status
RestoreInternal
(
IteratorContext
*
ctx
,
IteratorStateReader
*
reader
)
override
{
TF_RETURN_IF_ERROR
(
RestoreInput
(
ctx
,
reader
,
input_impl_
));
return
Status
::
OK
();
}
private:
std
::
unique_ptr
<
IteratorBase
>
input_impl_
;
};
std
::
unique_ptr
<
IteratorBase
>
GraphRewriteDataset
::
MakeIteratorInternal
(
const
string
&
prefix
)
const
{
// We do not add a token for this dataset to the prefix. The
// prefix is used to identify checkpoint elements and since this
// dataset is excluded from the checkpoint, adding a token
// here would result in invalid checkpoint identifiers.
return
absl
::
make_unique
<
Iterator
>
(
Iterator
::
Params
{
this
,
prefix
});
}
}
// namespace data
}
// namespace tensorflow
tensorflow/core/kernels/data/graph_rewrite_dataset.h
已删除
100644 → 0
浏览文件 @
a85bbeb7
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_
#define TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
namespace
tensorflow
{
namespace
data
{
class
GraphRewriteDataset
:
public
DatasetBase
{
public:
GraphRewriteDataset
(
OpKernelContext
*
ctx
,
const
DatasetBase
*
input
,
const
DataTypeVector
&
output_types
,
const
std
::
vector
<
PartialTensorShape
>&
output_shapes
)
:
DatasetBase
(
DatasetContext
(
ctx
)),
optimized_input_
(
nullptr
),
input_
(
input
),
output_types_
(
output_types
),
output_shapes_
(
output_shapes
)
{
input_
->
Ref
();
}
~
GraphRewriteDataset
()
override
;
// Runs Grappler to transform the input dataset into optimized_input_
// dataset.
Status
Optimize
(
OpKernelContext
*
ctx
);
std
::
unique_ptr
<
IteratorBase
>
MakeIteratorInternal
(
const
string
&
prefix
)
const
override
;
const
DataTypeVector
&
output_dtypes
()
const
override
{
return
output_types_
;
}
const
std
::
vector
<
PartialTensorShape
>&
output_shapes
()
const
override
{
return
output_shapes_
;
}
int64
Cardinality
()
const
override
{
return
input_
->
Cardinality
();
}
protected:
Status
AsGraphDefInternal
(
SerializationContext
*
ctx
,
DatasetGraphDefBuilder
*
b
,
Node
**
output
)
const
override
;
private:
class
Iterator
;
// Create a Grappler RewriteConfig proto that defines the list of
// optimizations to be run by the Grappler Meta Optimizer.
virtual
RewriterConfig
CreateGrapplerRewriteConfig
()
=
0
;
// Option specifying whether we want to optimize the function library as well.
virtual
bool
ShouldOptimizeFunctions
()
{
return
true
;
}
Status
ApplyOptimizations
(
OpKernelContext
*
ctx
,
GraphDef
*
graph_def
,
string
*
output_node
);
DatasetBase
*
optimized_input_
;
FunctionLibraryRuntime
*
flr_
=
nullptr
;
std
::
unique_ptr
<
ProcessFunctionLibraryRuntime
>
pflr_
=
nullptr
;
std
::
unique_ptr
<
FunctionLibraryDefinition
>
lib_def_
=
nullptr
;
std
::
unique_ptr
<
FunctionHandleCache
>
function_handle_cache_
=
nullptr
;
const
DatasetBase
*
input_
;
const
DataTypeVector
output_types_
;
const
std
::
vector
<
PartialTensorShape
>
output_shapes_
;
};
}
// namespace data
}
// namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_
tensorflow/core/kernels/data/optimize_dataset_op.cc
浏览文件 @
06fb6333
...
...
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/
graph_rewrite_dataset
.h"
#include "tensorflow/core/kernels/data/
dataset_utils
.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
...
...
@@ -32,12 +32,9 @@ constexpr char kOptimizerName[] = "tf_data_meta_optimizer";
class
OptimizeDatasetOp
:
public
UnaryDatasetOpKernel
{
public:
explicit
OptimizeDatasetOp
(
OpKernelConstruction
*
ctx
)
:
UnaryDatasetOpKernel
(
ctx
),
graph_def_version_
(
ctx
->
graph_def_version
())
{
OP_REQUIRES_OK
(
ctx
,
ctx
->
GetAttr
(
"output_types"
,
&
output_types_
));
OP_REQUIRES_OK
(
ctx
,
ctx
->
GetAttr
(
"output_shapes"
,
&
output_shapes_
));
OP_REQUIRES_OK
(
ctx
,
ctx
->
GetAttr
(
"optimization_configs"
,
&
optimizer_configs_
));
:
UnaryDatasetOpKernel
(
ctx
)
{
OP_REQUIRES_OK
(
ctx
,
ctx
->
GetAttr
(
"optimization_configs"
,
&
optimization_configs_
));
}
protected:
...
...
@@ -46,62 +43,41 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
std
::
vector
<
string
>
optimizations
;
OP_REQUIRES_OK
(
ctx
,
ParseVectorArgument
<
string
>
(
ctx
,
"optimizations"
,
&
optimizations
));
Dataset
*
dataset
=
new
Dataset
(
ctx
,
input
,
optimizations
,
output_types_
,
output_shapes_
,
optimizer_configs_
);
Status
s
=
dataset
->
Optimize
(
ctx
);
if
(
s
.
ok
())
{
*
output
=
dataset
;
}
else
{
dataset
->
Unref
();
OP_REQUIRES_OK
(
ctx
,
s
);
}
auto
config_factory
=
[
this
,
&
optimizations
]()
{
return
CreateConfig
(
optimizations
,
optimization_configs_
);
};
OP_REQUIRES_OK
(
ctx
,
RewriteDataset
(
ctx
,
input
,
std
::
move
(
config_factory
),
/*optimize_function_library=*/
true
,
output
));
}
private:
class
Dataset
:
public
GraphRewriteDataset
{
public:
Dataset
(
OpKernelContext
*
ctx
,
const
DatasetBase
*
input
,
const
std
::
vector
<
string
>&
optimizations
,
const
DataTypeVector
&
output_types
,
const
std
::
vector
<
PartialTensorShape
>&
output_shapes
,
const
std
::
vector
<
string
>&
optimizer_configs
)
:
GraphRewriteDataset
(
ctx
,
input
,
output_types
,
output_shapes
),
optimizations_
(
optimizations
),
optimizer_configs_
(
optimizer_configs
)
{}
string
DebugString
()
const
override
{
return
"OptimizeDatasetOp::Dataset"
;
}
private:
RewriterConfig
CreateGrapplerRewriteConfig
()
override
{
RewriterConfig
rewriter_config
;
rewriter_config
.
add_optimizers
(
kOptimizerName
);
rewriter_config
.
set_meta_optimizer_iterations
(
RewriterConfig_NumIterationsType_ONE
);
auto
custom_optimizer
=
rewriter_config
.
add_custom_optimizers
();
custom_optimizer
->
set_name
(
kOptimizerName
);
auto
*
custom_optimizations_list
=
(
*
custom_optimizer
->
mutable_parameter_map
())[
"optimizers"
]
.
mutable_list
();
for
(
const
auto
&
opt
:
optimizations_
)
{
custom_optimizations_list
->
add_s
(
opt
);
}
auto
*
config_list
=
(
*
custom_optimizer
->
mutable_parameter_map
())[
"optimizer_configs"
]
.
mutable_list
();
for
(
const
auto
&
config
:
optimizer_configs_
)
{
config_list
->
add_s
(
config
);
}
return
rewriter_config
;
static
RewriterConfig
CreateConfig
(
std
::
vector
<
string
>
optimizations
,
std
::
vector
<
string
>
optimizations_configs
)
{
RewriterConfig
rewriter_config
;
rewriter_config
.
add_optimizers
(
kOptimizerName
);
rewriter_config
.
set_meta_optimizer_iterations
(
RewriterConfig_NumIterationsType_ONE
);
auto
custom_optimizer
=
rewriter_config
.
add_custom_optimizers
();
custom_optimizer
->
set_name
(
kOptimizerName
);
auto
*
custom_optimizations_list
=
(
*
custom_optimizer
->
mutable_parameter_map
())[
"optimizers"
]
.
mutable_list
();
for
(
const
auto
&
opt
:
optimizations
)
{
custom_optimizations_list
->
add_s
(
opt
);
}
auto
*
config_list
=
(
*
custom_optimizer
->
mutable_parameter_map
())[
"optimizer_configs"
]
.
mutable_list
();
for
(
const
auto
&
config
:
optimizations_configs
)
{
config_list
->
add_s
(
config
);
}
return
rewriter_config
;
}
const
std
::
vector
<
string
>
optimizations_
;
const
std
::
vector
<
string
>
optimizer_configs_
;
};
const
int
graph_def_version_
;
DataTypeVector
output_types_
;
std
::
vector
<
PartialTensorShape
>
output_shapes_
;
std
::
vector
<
string
>
optimizer_configs_
;
std
::
vector
<
string
>
optimization_configs_
;
};
REGISTER_KERNEL_BUILDER
(
Name
(
"OptimizeDataset"
).
Device
(
DEVICE_CPU
),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录