Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wux_labs
Tensorflow
提交
c9bdd393
T
Tensorflow
项目概览
wux_labs
/
Tensorflow
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
Tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
c9bdd393
编写于
10月 03, 2018
作者:
D
Derek Murray
提交者:
TensorFlower Gardener
10月 03, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[tf.data] Switch background threads to use `BackgroundWorker`.
PiperOrigin-RevId: 215579950
上级
6b0d1ec9
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
46 addition
and
37 deletion
+46
-37
tensorflow/core/kernels/data/iterator_ops.cc
tensorflow/core/kernels/data/iterator_ops.cc
+0
-4
tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+6
-4
tensorflow/core/kernels/data/model_dataset_op.cc
tensorflow/core/kernels/data/model_dataset_op.cc
+6
-4
tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
...rflow/core/kernels/data/parallel_interleave_dataset_op.cc
+16
-11
tensorflow/core/kernels/data/parallel_map_iterator.cc
tensorflow/core/kernels/data/parallel_map_iterator.cc
+6
-4
tensorflow/core/kernels/data/prefetch_dataset_op.cc
tensorflow/core/kernels/data/prefetch_dataset_op.cc
+6
-4
tensorflow/core/kernels/data/writer_ops.cc
tensorflow/core/kernels/data/writer_ops.cc
+6
-6
未找到文件。
tensorflow/core/kernels/data/iterator_ops.cc
浏览文件 @
c9bdd393
...
@@ -16,10 +16,8 @@ limitations under the License.
...
@@ -16,10 +16,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/iterator.pb.h"
#include "tensorflow/core/framework/iterator.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_op_registry.h"
...
@@ -27,13 +25,11 @@ limitations under the License.
...
@@ -27,13 +25,11 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session_options.h"
namespace
tensorflow
{
namespace
tensorflow
{
namespace
data
{
namespace
data
{
...
...
tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
浏览文件 @
c9bdd393
...
@@ -29,6 +29,7 @@ limitations under the License.
...
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/util/ptr_util.h"
namespace
tensorflow
{
namespace
tensorflow
{
namespace
data
{
namespace
data
{
...
@@ -405,9 +406,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
...
@@ -405,9 +406,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED
(
*
mu_
)
{
EXCLUSIVE_LOCKS_REQUIRED
(
*
mu_
)
{
if
(
!
runner_thread_
)
{
if
(
!
runner_thread_
)
{
std
::
shared_ptr
<
IteratorContext
>
ctx_copy
(
new
IteratorContext
(
*
ctx
));
std
::
shared_ptr
<
IteratorContext
>
ctx_copy
(
new
IteratorContext
(
*
ctx
));
runner_thread_
.
reset
(
ctx
->
env
()
->
StartThread
(
runner_thread_
=
{},
"runner_thread"
,
MakeUnique
<
BackgroundWorker
>
(
ctx
->
env
(),
"runner_thread"
);
std
::
bind
(
&
Iterator
::
RunnerThread
,
this
,
ctx_copy
)));
runner_thread_
->
Schedule
(
std
::
bind
(
&
Iterator
::
RunnerThread
,
this
,
ctx_copy
));
}
}
}
}
...
@@ -660,7 +662,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
...
@@ -660,7 +662,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
std
::
unique_ptr
<
IteratorBase
>
input_impl_
;
std
::
unique_ptr
<
IteratorBase
>
input_impl_
;
// Buffer for storing the (intermediate) batch results.
// Buffer for storing the (intermediate) batch results.
std
::
deque
<
std
::
shared_ptr
<
BatchResult
>>
batch_results_
GUARDED_BY
(
*
mu_
);
std
::
deque
<
std
::
shared_ptr
<
BatchResult
>>
batch_results_
GUARDED_BY
(
*
mu_
);
std
::
unique_ptr
<
Thread
>
runner_thread_
GUARDED_BY
(
*
mu_
);
std
::
unique_ptr
<
BackgroundWorker
>
runner_thread_
GUARDED_BY
(
*
mu_
);
bool
cancelled_
GUARDED_BY
(
*
mu_
)
=
false
;
bool
cancelled_
GUARDED_BY
(
*
mu_
)
=
false
;
};
};
...
...
tensorflow/core/kernels/data/model_dataset_op.cc
浏览文件 @
c9bdd393
...
@@ -18,6 +18,7 @@ limitations under the License.
...
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/util/ptr_util.h"
namespace
tensorflow
{
namespace
tensorflow
{
namespace
data
{
namespace
data
{
...
@@ -126,9 +127,10 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
...
@@ -126,9 +127,10 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED
(
mu_
)
{
EXCLUSIVE_LOCKS_REQUIRED
(
mu_
)
{
if
(
!
optimize_thread_
)
{
if
(
!
optimize_thread_
)
{
std
::
shared_ptr
<
IteratorContext
>
new_ctx
(
new
IteratorContext
(
*
ctx
));
std
::
shared_ptr
<
IteratorContext
>
new_ctx
(
new
IteratorContext
(
*
ctx
));
optimize_thread_
.
reset
(
ctx
->
env
()
->
StartThread
(
optimize_thread_
=
{},
"optimize_thread"
,
MakeUnique
<
BackgroundWorker
>
(
ctx
->
env
(),
"optimize_thread"
);
[
this
,
new_ctx
]()
{
OptimizeThread
(
new_ctx
);
}));
optimize_thread_
->
Schedule
(
[
this
,
new_ctx
]()
{
OptimizeThread
(
new_ctx
);
});
}
}
return
Status
::
OK
();
return
Status
::
OK
();
}
}
...
@@ -167,7 +169,7 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
...
@@ -167,7 +169,7 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
mutex
mu_
;
mutex
mu_
;
condition_variable
cond_var_
;
condition_variable
cond_var_
;
std
::
shared_ptr
<
model
::
Model
>
model_
;
std
::
shared_ptr
<
model
::
Model
>
model_
;
std
::
unique_ptr
<
Thread
>
optimize_thread_
GUARDED_BY
(
mu_
);
std
::
unique_ptr
<
BackgroundWorker
>
optimize_thread_
GUARDED_BY
(
mu_
);
bool
cancelled_
GUARDED_BY
(
mu_
)
=
false
;
bool
cancelled_
GUARDED_BY
(
mu_
)
=
false
;
std
::
unique_ptr
<
IteratorBase
>
input_impl_
GUARDED_BY
(
mu_
);
std
::
unique_ptr
<
IteratorBase
>
input_impl_
GUARDED_BY
(
mu_
);
};
};
...
...
tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
浏览文件 @
c9bdd393
...
@@ -26,6 +26,7 @@ limitations under the License.
...
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/util/ptr_util.h"
namespace
tensorflow
{
namespace
tensorflow
{
namespace
data
{
namespace
data
{
...
@@ -481,9 +482,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
...
@@ -481,9 +482,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
worker_threads_
.
reserve
(
dataset
()
->
num_threads
());
worker_threads_
.
reserve
(
dataset
()
->
num_threads
());
for
(
size_t
i
=
0
;
i
<
dataset
()
->
num_threads
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
dataset
()
->
num_threads
();
++
i
)
{
std
::
shared_ptr
<
IteratorContext
>
new_ctx
(
new
IteratorContext
(
*
ctx
));
std
::
shared_ptr
<
IteratorContext
>
new_ctx
(
new
IteratorContext
(
*
ctx
));
worker_threads_
.
emplace_back
(
ctx
->
env
()
->
StartThread
(
worker_threads_
.
emplace_back
(
{},
"worker_thread"
,
MakeUnique
<
BackgroundWorker
>
(
ctx
->
env
(),
"worker_thread"
));
[
this
,
new_ctx
,
i
]()
{
WorkerThread
(
new_ctx
,
i
);
}));
worker_threads_
.
back
()
->
Schedule
(
[
this
,
new_ctx
,
i
]()
{
WorkerThread
(
new_ctx
,
i
);
});
}
}
}
}
return
Status
::
OK
();
return
Status
::
OK
();
...
@@ -580,9 +582,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
...
@@ -580,9 +582,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
}
workers_
[
i
].
SetInputs
(
s
,
std
::
move
(
args
));
workers_
[
i
].
SetInputs
(
s
,
std
::
move
(
args
));
std
::
shared_ptr
<
IteratorContext
>
new_ctx
(
new
IteratorContext
(
*
ctx
));
std
::
shared_ptr
<
IteratorContext
>
new_ctx
(
new
IteratorContext
(
*
ctx
));
worker_threads_
.
emplace_back
(
ctx
->
env
()
->
StartThread
(
worker_threads_
.
emplace_back
(
{},
"worker_thread"
,
MakeUnique
<
BackgroundWorker
>
(
ctx
->
env
(),
"worker_thread"
));
[
this
,
new_ctx
,
i
]()
{
WorkerThread
(
new_ctx
,
i
);
}));
worker_threads_
.
back
()
->
Schedule
(
[
this
,
new_ctx
,
i
]()
{
WorkerThread
(
new_ctx
,
i
);
});
if
(
i
<
dataset
()
->
cycle_length_
)
{
if
(
i
<
dataset
()
->
cycle_length_
)
{
interleave_indices_
.
push_back
(
i
);
interleave_indices_
.
push_back
(
i
);
}
else
{
}
else
{
...
@@ -1047,7 +1050,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
...
@@ -1047,7 +1050,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// The worker threads. This must be last to ensure the
// The worker threads. This must be last to ensure the
// threads have exited before any other members are deallocated.
// threads have exited before any other members are deallocated.
// TODO(b/65178177): Avoid allocating additional threads.
// TODO(b/65178177): Avoid allocating additional threads.
std
::
vector
<
std
::
unique_ptr
<
Thread
>>
worker_threads_
GUARDED_BY
(
mu_
);
std
::
vector
<
std
::
unique_ptr
<
BackgroundWorker
>>
worker_threads_
GUARDED_BY
(
mu_
);
};
};
const
DatasetBase
*
const
input_
;
const
DatasetBase
*
const
input_
;
...
@@ -1389,9 +1393,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
...
@@ -1389,9 +1393,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
EXCLUSIVE_LOCKS_REQUIRED
(
*
mu_
)
{
EXCLUSIVE_LOCKS_REQUIRED
(
*
mu_
)
{
if
(
!
runner_thread_
)
{
if
(
!
runner_thread_
)
{
std
::
shared_ptr
<
IteratorContext
>
new_ctx
(
new
IteratorContext
(
*
ctx
));
std
::
shared_ptr
<
IteratorContext
>
new_ctx
(
new
IteratorContext
(
*
ctx
));
runner_thread_
.
reset
(
ctx
->
env
()
->
StartThread
(
runner_thread_
=
{},
"runner_thread"
,
MakeUnique
<
BackgroundWorker
>
(
ctx
->
env
(),
"runner_thread"
);
[
this
,
new_ctx
]()
{
RunnerThread
(
new_ctx
);
}));
runner_thread_
->
Schedule
(
[
this
,
new_ctx
]()
{
RunnerThread
(
new_ctx
);
});
}
}
}
}
...
@@ -1645,7 +1650,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
...
@@ -1645,7 +1650,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
int64
num_calls_
GUARDED_BY
(
*
mu_
)
=
0
;
int64
num_calls_
GUARDED_BY
(
*
mu_
)
=
0
;
std
::
unique_ptr
<
thread
::
ThreadPool
>
thread_pool_
;
std
::
unique_ptr
<
thread
::
ThreadPool
>
thread_pool_
;
std
::
unique_ptr
<
Thread
>
runner_thread_
GUARDED_BY
(
*
mu_
);
std
::
unique_ptr
<
BackgroundWorker
>
runner_thread_
GUARDED_BY
(
*
mu_
);
// Identifies whether background activity should be cancelled.
// Identifies whether background activity should be cancelled.
bool
cancelled_
GUARDED_BY
(
*
mu_
)
=
false
;
bool
cancelled_
GUARDED_BY
(
*
mu_
)
=
false
;
...
...
tensorflow/core/kernels/data/parallel_map_iterator.cc
浏览文件 @
c9bdd393
...
@@ -22,6 +22,7 @@ limitations under the License.
...
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/util/ptr_util.h"
namespace
tensorflow
{
namespace
tensorflow
{
namespace
data
{
namespace
data
{
...
@@ -180,9 +181,10 @@ class ParallelMapIterator : public DatasetBaseIterator {
...
@@ -180,9 +181,10 @@ class ParallelMapIterator : public DatasetBaseIterator {
EXCLUSIVE_LOCKS_REQUIRED
(
*
mu_
)
{
EXCLUSIVE_LOCKS_REQUIRED
(
*
mu_
)
{
if
(
!
runner_thread_
)
{
if
(
!
runner_thread_
)
{
std
::
shared_ptr
<
IteratorContext
>
ctx_copy
(
new
IteratorContext
(
*
ctx
));
std
::
shared_ptr
<
IteratorContext
>
ctx_copy
(
new
IteratorContext
(
*
ctx
));
runner_thread_
.
reset
(
ctx
->
env
()
->
StartThread
(
runner_thread_
=
{},
"runner_thread"
,
MakeUnique
<
BackgroundWorker
>
(
ctx
->
env
(),
"runner_thread"
);
std
::
bind
(
&
ParallelMapIterator
::
RunnerThread
,
this
,
ctx_copy
)));
runner_thread_
->
Schedule
(
std
::
bind
(
&
ParallelMapIterator
::
RunnerThread
,
this
,
ctx_copy
));
}
}
}
}
...
@@ -330,7 +332,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
...
@@ -330,7 +332,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
// Buffer for storing the invocation results.
// Buffer for storing the invocation results.
std
::
deque
<
std
::
shared_ptr
<
InvocationResult
>>
invocation_results_
std
::
deque
<
std
::
shared_ptr
<
InvocationResult
>>
invocation_results_
GUARDED_BY
(
*
mu_
);
GUARDED_BY
(
*
mu_
);
std
::
unique_ptr
<
Thread
>
runner_thread_
GUARDED_BY
(
*
mu_
);
std
::
unique_ptr
<
BackgroundWorker
>
runner_thread_
GUARDED_BY
(
*
mu_
);
bool
cancelled_
GUARDED_BY
(
*
mu_
)
=
false
;
bool
cancelled_
GUARDED_BY
(
*
mu_
)
=
false
;
};
};
...
...
tensorflow/core/kernels/data/prefetch_dataset_op.cc
浏览文件 @
c9bdd393
...
@@ -22,6 +22,7 @@ limitations under the License.
...
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/util/ptr_util.h"
namespace
tensorflow
{
namespace
tensorflow
{
namespace
data
{
namespace
data
{
...
@@ -256,10 +257,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
...
@@ -256,10 +257,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
Status
EnsurePrefetchThreadStarted
(
IteratorContext
*
ctx
)
Status
EnsurePrefetchThreadStarted
(
IteratorContext
*
ctx
)
EXCLUSIVE_LOCKS_REQUIRED
(
mu_
)
{
EXCLUSIVE_LOCKS_REQUIRED
(
mu_
)
{
if
(
!
prefetch_thread_
)
{
if
(
!
prefetch_thread_
)
{
prefetch_thread_
=
MakeUnique
<
BackgroundWorker
>
(
ctx
->
env
(),
"prefetch_thread"
);
std
::
shared_ptr
<
IteratorContext
>
new_ctx
(
new
IteratorContext
(
*
ctx
));
std
::
shared_ptr
<
IteratorContext
>
new_ctx
(
new
IteratorContext
(
*
ctx
));
prefetch_thread_
.
reset
(
ctx
->
env
()
->
StartThread
(
prefetch_thread_
->
Schedule
(
{},
"prefetch_thread"
,
[
this
,
new_ctx
]()
{
PrefetchThread
(
new_ctx
);
});
[
this
,
new_ctx
]()
{
PrefetchThread
(
new_ctx
);
}));
}
}
return
Status
::
OK
();
return
Status
::
OK
();
}
}
...
@@ -363,7 +365,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
...
@@ -363,7 +365,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
string
prefix_end_
;
string
prefix_end_
;
PrefetchAutotuner
auto_tuner_
GUARDED_BY
(
mu_
);
PrefetchAutotuner
auto_tuner_
GUARDED_BY
(
mu_
);
std
::
deque
<
BufferElement
>
buffer_
GUARDED_BY
(
mu_
);
std
::
deque
<
BufferElement
>
buffer_
GUARDED_BY
(
mu_
);
std
::
unique_ptr
<
Thread
>
prefetch_thread_
GUARDED_BY
(
mu_
);
std
::
unique_ptr
<
BackgroundWorker
>
prefetch_thread_
GUARDED_BY
(
mu_
);
bool
cancelled_
GUARDED_BY
(
mu_
)
=
false
;
bool
cancelled_
GUARDED_BY
(
mu_
)
=
false
;
bool
prefetch_thread_finished_
GUARDED_BY
(
mu_
)
=
false
;
bool
prefetch_thread_finished_
GUARDED_BY
(
mu_
)
=
false
;
};
};
...
...
tensorflow/core/kernels/data/writer_ops.cc
浏览文件 @
c9bdd393
...
@@ -29,10 +29,10 @@ class ToTFRecordOp : public AsyncOpKernel {
...
@@ -29,10 +29,10 @@ class ToTFRecordOp : public AsyncOpKernel {
public:
public:
explicit
ToTFRecordOp
(
OpKernelConstruction
*
ctx
)
explicit
ToTFRecordOp
(
OpKernelConstruction
*
ctx
)
:
AsyncOpKernel
(
ctx
),
:
AsyncOpKernel
(
ctx
),
thread_pool_
(
new
thread
::
ThreadPool
(
background_worker_
(
ctx
->
env
(),
ThreadOptions
(),
ctx
->
env
(),
strings
::
StrCat
(
"to_tf_record_
_op_"
,
SanitizeThreadSuffix
(
name
())),
strings
::
StrCat
(
"to_tf_record_
op_"
,
SanitizeThreadSuffix
(
name
())))
{
1
/* num_threads */
,
false
/* low_latency_hint */
))
{
}
}
template
<
typename
T
>
template
<
typename
T
>
Status
ParseScalarArgument
(
OpKernelContext
*
ctx
,
Status
ParseScalarArgument
(
OpKernelContext
*
ctx
,
...
@@ -50,7 +50,7 @@ class ToTFRecordOp : public AsyncOpKernel {
...
@@ -50,7 +50,7 @@ class ToTFRecordOp : public AsyncOpKernel {
// The call to `iterator->GetNext()` may block and depend on an
// The call to `iterator->GetNext()` may block and depend on an
// inter-op thread pool thread, so we issue the call from the
// inter-op thread pool thread, so we issue the call from the
// owned thread pool.
// owned thread pool.
thread_pool_
->
Schedule
([
this
,
ctx
,
done
]()
{
background_worker_
.
Schedule
([
this
,
ctx
,
done
]()
{
string
filename
;
string
filename
;
OP_REQUIRES_OK_ASYNC
(
OP_REQUIRES_OK_ASYNC
(
ctx
,
ParseScalarArgument
<
string
>
(
ctx
,
"filename"
,
&
filename
),
done
);
ctx
,
ParseScalarArgument
<
string
>
(
ctx
,
"filename"
,
&
filename
),
done
);
...
@@ -97,7 +97,7 @@ class ToTFRecordOp : public AsyncOpKernel {
...
@@ -97,7 +97,7 @@ class ToTFRecordOp : public AsyncOpKernel {
}
}
private:
private:
std
::
unique_ptr
<
thread
::
ThreadPool
>
thread_pool
_
;
BackgroundWorker
background_worker
_
;
};
};
REGISTER_KERNEL_BUILDER
(
Name
(
"DatasetToTFRecord"
).
Device
(
DEVICE_CPU
),
REGISTER_KERNEL_BUILDER
(
Name
(
"DatasetToTFRecord"
).
Device
(
DEVICE_CPU
),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录