Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
9dd35dc5
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
9dd35dc5
编写于
7月 16, 2019
作者:
A
Andrew Audibert
提交者:
TensorFlower Gardener
7月 16, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Automated rollback of commit
0d2891d0
PiperOrigin-RevId: 258445832
上级
41b0cfa4
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
595 addition
and
0 deletion
+595
-0
tensorflow/core/common_runtime/data/BUILD
tensorflow/core/common_runtime/data/BUILD
+29
-0
tensorflow/core/common_runtime/data/standalone.cc
tensorflow/core/common_runtime/data/standalone.cc
+139
-0
tensorflow/core/common_runtime/data/standalone.h
tensorflow/core/common_runtime/data/standalone.h
+120
-0
tensorflow/core/common_runtime/data/standalone_test.cc
tensorflow/core/common_runtime/data/standalone_test.cc
+307
-0
未找到文件。
tensorflow/core/common_runtime/data/BUILD
0 → 100644
浏览文件 @
9dd35dc5
package
(
licenses
=
[
"notice"
],
# Apache 2.0
)
load
(
"//tensorflow:tensorflow.bzl"
,
"tf_cc_test"
)
load
(
"//tensorflow/core:platform/default/build_config.bzl"
,
"tf_protos_all"
)
cc_library
(
name
=
"standalone"
,
srcs
=
[
"standalone.cc"
],
hdrs
=
[
"standalone.h"
],
deps
=
[
"//tensorflow/core:core_cpu_internal"
,
"//tensorflow/core:framework"
,
"//tensorflow/core:lib"
,
"//tensorflow/core:session_options"
,
],
)
tf_cc_test
(
name
=
"standalone_test"
,
srcs
=
[
"standalone_test.cc"
],
deps
=
[
":standalone"
,
"//tensorflow/core:all_kernels"
,
"//tensorflow/core:test"
,
"//tensorflow/core:test_main"
,
]
+
tf_protos_all
(),
)
tensorflow/core/common_runtime/data/standalone.cc
0 → 100644
浏览文件 @
9dd35dc5
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/data/standalone.h"
#include <memory>
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/ptr_util.h"
namespace
tensorflow
{
namespace
data
{
namespace
standalone
{
Status
Iterator
::
GetNext
(
std
::
vector
<
Tensor
>*
outputs
,
bool
*
end_of_input
)
{
return
iterator_
->
GetNext
(
ctx_
.
get
(),
outputs
,
end_of_input
);
}
Iterator
::
Iterator
(
IteratorBase
*
iterator
,
IteratorContext
*
ctx
)
:
iterator_
(
iterator
),
ctx_
(
ctx
)
{}
Status
Dataset
::
FromGraph
(
Params
params
,
const
GraphDef
&
graph_def
,
std
::
unique_ptr
<
Dataset
>*
result
)
{
Graph
graph
(
OpRegistry
::
Global
());
TF_RETURN_IF_ERROR
(
ImportGraphDef
({},
graph_def
,
&
graph
,
nullptr
));
// Instantiate enough of the TensorFlow runtime to run `graph` on a single CPU
// device.
std
::
unique_ptr
<
DeviceMgr
>
device_mgr
=
MakeUnique
<
DeviceMgr
>
(
DeviceFactory
::
NewDevice
(
"CPU"
,
params
.
session_options
,
"/job:localhost/replica:0/task:0"
));
Device
*
device
=
device_mgr
->
ListDevices
()[
0
];
// Clone the `FunctionLibraryDefinition` to extend its lifetime extends beyond
// the lifetime of `graph`.
std
::
unique_ptr
<
FunctionLibraryDefinition
>
flib_def
=
MakeUnique
<
FunctionLibraryDefinition
>
(
graph
.
flib_def
());
std
::
unique_ptr
<
ProcessFunctionLibraryRuntime
>
pflr
=
MakeUnique
<
ProcessFunctionLibraryRuntime
>
(
device_mgr
.
get
(),
Env
::
Default
(),
TF_GRAPH_DEF_VERSION
,
flib_def
.
get
(),
OptimizerOptions
{},
nullptr
/* parent */
);
string
fetch_node
=
""
;
for
(
auto
node
:
graph_def
.
node
())
{
if
(
node
.
op
()
==
"_Retval"
)
{
fetch_node
=
node
.
input
(
0
);
}
}
if
(
fetch_node
.
empty
())
{
return
errors
::
NotFound
(
"Failed to find a _Retval op in the given dataset"
);
}
// Run graph up to `output_node` and extract the `DatasetBase` stored in the
// DT_VARIANT output tensor.
data
::
DatasetBase
*
dataset
;
{
std
::
vector
<
Tensor
>
outputs
;
GraphRunner
graph_runner
(
device
);
TF_RETURN_IF_ERROR
(
graph_runner
.
Run
(
&
graph
,
pflr
->
GetFLR
(
"/device:CPU:0"
),
{},
{
fetch_node
},
&
outputs
));
TF_RETURN_IF_ERROR
(
GetDatasetFromVariantTensor
(
outputs
[
0
],
&
dataset
));
// NOTE(mrry): The dataset is currently owned by `outputs[0]`, so acquire an
// additional reference.
dataset
->
Ref
();
}
std
::
unique_ptr
<
thread
::
ThreadPool
>
pool
(
NewThreadPoolFromSessionOptions
(
params
.
session_options
));
*
result
=
WrapUnique
(
new
Dataset
(
dataset
,
device_mgr
.
release
(),
pflr
.
release
(),
flib_def
.
release
(),
pool
.
release
()));
return
Status
::
OK
();
}
// static
Status
Dataset
::
MakeIterator
(
std
::
unique_ptr
<
Iterator
>*
result
)
{
// Create an `IteratorContext`, which bundles together the necessary runtime
// support to create and get elements from an iterator.
std
::
unique_ptr
<
IteratorContext
>
ctx
;
{
// NOTE(mrry): In the current API, an `IteratorContext` is always initially
// created from an `OpKernelContext*`, so we need to create a fake
// `OpKernelContext` with the appropriate subset of parameters.
OpKernelContext
::
Params
op_params
;
op_params
.
function_library
=
pflr_
->
GetFLR
(
"/device:CPU:0"
);
op_params
.
device
=
device_mgr_
->
ListDevices
()[
0
];
op_params
.
runner
=
&
runner_
;
OpKernelContext
op_ctx
(
&
op_params
,
0
);
IteratorContext
::
Params
params
(
&
op_ctx
);
params
.
function_handle_cache
=
function_handle_cache_
.
get
();
ctx
=
MakeUnique
<
IteratorContext
>
(
std
::
move
(
params
));
}
// Create the iterator from the dataset.
std
::
unique_ptr
<
IteratorBase
>
iterator
;
TF_RETURN_IF_ERROR
(
dataset_
->
MakeIterator
(
ctx
.
get
(),
"iterator"
,
&
iterator
));
*
result
=
WrapUnique
(
new
Iterator
(
iterator
.
release
(),
ctx
.
release
()));
return
Status
::
OK
();
}
Dataset
::
Dataset
(
DatasetBase
*
dataset
,
DeviceMgr
*
device_mgr
,
ProcessFunctionLibraryRuntime
*
pflr
,
FunctionLibraryDefinition
*
flib_def
,
thread
::
ThreadPool
*
pool
)
:
dataset_
(
dataset
),
device_mgr_
(
device_mgr
),
flib_def_
(
flib_def
),
pflr_
(
pflr
),
pool_
(
pool
)
{
runner_
=
[
this
](
std
::
function
<
void
()
>
c
)
{
pool_
->
Schedule
(
std
::
move
(
c
));
};
function_handle_cache_
=
MakeUnique
<
FunctionHandleCache
>
(
pflr_
->
GetFLR
(
"/device:CPU:0"
));
}
Dataset
::~
Dataset
()
{
dataset_
->
Unref
();
}
}
// namespace standalone
}
// namespace data
}
// namespace tensorflow
tensorflow/core/common_runtime/data/standalone.h
0 → 100644
浏览文件 @
9dd35dc5
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DATA_STANDALONE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_DATA_STANDALONE_H_
#include <memory>
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/public/session_options.h"
namespace
tensorflow
{
namespace
data
{
namespace
standalone
{
// The purpose of the API in this file is to facilitate standalone execution of
// a tf.data input pipeline graph.
//
// The API exposes two abstractions -- a `Dataset` and an `Iterator` -- which
// encapsulate TensorFlow runtime.
//
// The `Dataset` abstraction represents an input pipeline as a collection
// of data sources and a logical plan of transformations that operate over the
// data.
//
// The `Iterator` abstraction represents an execution of an input pipeline that
// can be used to enumerate its elements.
//
// Example usage:
//
// // Create a `Dataset` by running the `graph_def` graph.
// tensorflow::data:standalone::Dataset::Params params;
// std::unique_ptr<tensorflow::data::standalone::Dataset> dataset;
// Status s = tensorflow::data::standalone::Dataset::FromGraph(
// params, graph_def, &dataset);
// if (!s.ok()) { /* error handling */ }
//
// std::unique_ptr<tensorflow::data::standalone::Iterator> iterator;
// s = dataset->MakeIterator(&iterator);
// if (!s.ok()) { /* error handling */ }
//
// bool end_of_input = false;
// while (!end_of_input) {
// std::vector<tensorflow::Tensor> outputs;
// s = iterator->GetNext(&outputs, &end_of_input);
// if (!s.ok()) { /* error handling */ }
// if (!end_of_input) { /* output handling */ }
// }
class
Dataset
;
// Represents an execution of an input pipeline that can be used to enumerate
// its elements.
class
Iterator
{
public:
// Returns the next element of the input pipeline (if there is one) and an
// indication of whether the end of the input pipeline has been reached.
Status
GetNext
(
std
::
vector
<
Tensor
>*
outputs
,
bool
*
end_of_input
);
private:
friend
class
Dataset
;
Iterator
(
IteratorBase
*
iterator
,
IteratorContext
*
ctx
);
std
::
unique_ptr
<
IteratorBase
>
iterator_
;
std
::
unique_ptr
<
IteratorContext
>
ctx_
;
};
// Represents an input pipeline as a collection of data sources and a logical
// plan of transformations that operate over the data.
class
Dataset
{
public:
// Parameters for `Dataset` creation (e.g. TensorFlow runtime configuration).
struct
Params
{
SessionOptions
session_options
;
};
// Creates a new `Dataset` instance by running the given dataset graph.
static
Status
FromGraph
(
Params
params
,
const
GraphDef
&
graph_def
,
std
::
unique_ptr
<
Dataset
>*
result
);
~
Dataset
();
// Creates an iterator for this dataset.
Status
MakeIterator
(
std
::
unique_ptr
<
Iterator
>*
result
);
private:
Dataset
(
DatasetBase
*
dataset
,
DeviceMgr
*
device_mgr
,
ProcessFunctionLibraryRuntime
*
pflr
,
FunctionLibraryDefinition
*
flib_def
,
thread
::
ThreadPool
*
pool
);
DatasetBase
*
dataset_
;
// owned
std
::
unique_ptr
<
DeviceMgr
>
device_mgr_
;
std
::
unique_ptr
<
FunctionLibraryDefinition
>
flib_def_
;
std
::
unique_ptr
<
ProcessFunctionLibraryRuntime
>
pflr_
;
std
::
unique_ptr
<
thread
::
ThreadPool
>
pool_
;
std
::
unique_ptr
<
FunctionHandleCache
>
function_handle_cache_
;
std
::
function
<
void
(
std
::
function
<
void
()
>
)
>
runner_
;
};
}
// namespace standalone
}
// namespace data
}
// namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DATA_STANDALONE_H_
tensorflow/core/common_runtime/data/standalone_test.cc
0 → 100644
浏览文件 @
9dd35dc5
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/data/standalone.h"
#include <memory>
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace
tensorflow
{
namespace
data
{
namespace
standalone
{
namespace
{
constexpr
const
char
*
const
kRangeGraphProto
=
R"proto(
node {
name: "Const/_0"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT64 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT64
tensor_shape {}
int64_val: 0
}
}
}
}
node {
name: "Const/_1"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT64 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT64
tensor_shape {}
int64_val: 10
}
}
}
}
node {
name: "Const/_2"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT64 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT64
tensor_shape {}
int64_val: 1
}
}
}
}
node {
name: "RangeDataset/_3"
op: "RangeDataset"
input: "Const/_0"
input: "Const/_1"
input: "Const/_2"
attr {
key: "output_shapes"
value { list { shape {} } }
}
attr {
key: "output_types"
value { list { type: DT_INT64 } }
}
}
node {
name: "dataset"
op: "_Retval"
input: "RangeDataset/_3"
attr {
key: "T"
value { type: DT_VARIANT }
}
attr {
key: "index"
value { i: 0 }
}
}
library {}
versions { producer: 96 }
)proto"
;
// range(10).map(lambda x: x*x)
constexpr
const
char
*
const
kMapGraphProto
=
R"proto(
node {
name: "Const/_0"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT64 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT64
tensor_shape {}
int64_val: 0
}
}
}
}
node {
name: "Const/_1"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT64 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT64
tensor_shape {}
int64_val: 10
}
}
}
}
node {
name: "Const/_2"
op: "Const"
attr {
key: "dtype"
value { type: DT_INT64 }
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT64
tensor_shape {}
int64_val: 1
}
}
}
}
node {
name: "RangeDataset/_3"
op: "RangeDataset"
input: "Const/_0"
input: "Const/_1"
input: "Const/_2"
attr {
key: "output_shapes"
value { list { shape {} } }
}
attr {
key: "output_types"
value { list { type: DT_INT64 } }
}
}
node {
name: "MapDataset/_4"
op: "MapDataset"
input: "RangeDataset/_3"
attr {
key: "Targuments"
value { list {} }
}
attr {
key: "f"
value { func { name: "__inference_Dataset_map_<lambda>_67" } }
}
attr {
key: "output_shapes"
value { list { shape {} } }
}
attr {
key: "output_types"
value { list { type: DT_INT64 } }
}
attr {
key: "preserve_cardinality"
value { b: false }
}
attr {
key: "use_inter_op_parallelism"
value { b: true }
}
}
node {
name: "dataset"
op: "_Retval"
input: "MapDataset/_4"
attr {
key: "T"
value { type: DT_VARIANT }
}
attr {
key: "index"
value { i: 0 }
}
}
library {
function {
signature {
name: "__inference_Dataset_map_<lambda>_67"
input_arg { name: "args_0" type: DT_INT64 }
output_arg { name: "identity" type: DT_INT64 }
}
node_def {
name: "mul"
op: "Mul"
input: "args_0"
input: "args_0"
attr {
key: "T"
value { type: DT_INT64 }
}
}
node_def {
name: "Identity"
op: "Identity"
input: "mul:z:0"
attr {
key: "T"
value { type: DT_INT64 }
}
}
ret { key: "identity" value: "Identity:output:0" }
arg_attr {
key: 0
value {
attr {
key: "_user_specified_name"
value { s: "args_0" }
}
}
}
}
}
versions { producer: 96 min_consumer: 12 }
)proto"
;
TEST
(
Scalar
,
Standalone
)
{
struct
TestCase
{
string
graph_string
;
std
::
vector
<
int64
>
expected_outputs
;
};
auto
test_cases
=
{
TestCase
{
kRangeGraphProto
,
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
}},
TestCase
{
kMapGraphProto
,
{
0
,
1
,
4
,
9
,
16
,
25
,
36
,
49
,
64
,
81
}},
};
for
(
auto
test_case
:
test_cases
)
{
GraphDef
graph_def
;
protobuf
::
TextFormat
::
ParseFromString
(
test_case
.
graph_string
,
&
graph_def
);
std
::
unique_ptr
<
Dataset
>
dataset
;
auto
s
=
Dataset
::
FromGraph
({},
graph_def
,
&
dataset
);
TF_EXPECT_OK
(
s
);
std
::
unique_ptr
<
Iterator
>
iterator
;
s
=
dataset
->
MakeIterator
(
&
iterator
);
TF_EXPECT_OK
(
s
);
bool
end_of_input
=
false
;
for
(
int
num_outputs
=
0
;
!
end_of_input
;
++
num_outputs
)
{
std
::
vector
<
tensorflow
::
Tensor
>
outputs
;
s
=
iterator
->
GetNext
(
&
outputs
,
&
end_of_input
);
TF_EXPECT_OK
(
s
);
if
(
!
end_of_input
)
{
EXPECT_EQ
(
outputs
[
0
].
scalar
<
int64
>
()(),
test_case
.
expected_outputs
[
num_outputs
]);
}
else
{
EXPECT_EQ
(
test_case
.
expected_outputs
.
size
(),
num_outputs
);
}
}
}
}
}
// namespace
}
// namespace standalone
}
// namespace data
}
// namespace tensorflow
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录