xxadev / tensorflow
Commit b3d5ec90

Authored on Oct 20, 2017 by Vijay Vasudevan; committed via GitHub on Oct 20, 2017.

Merge pull request #13866 from vrv/branch_172924803

Branch 172924803

Parents: a528ccdb, cf336b33
Showing 128 changed files with 6849 additions and 2251 deletions (+6849, -2251).
WORKSPACE  +1  -1
tensorflow/BUILD  +2  -1
tensorflow/c/eager/BUILD  +2  -1
tensorflow/c/eager/c_api.cc  +68  -13
tensorflow/c/eager/c_api.h  +34  -2
tensorflow/c/eager/c_api_internal.h  +7  -0
tensorflow/c/eager/c_api_test.cc  +64  -18
tensorflow/compiler/xla/client/computation_builder.h  +26  -15
tensorflow/compiler/xla/layout_util.cc  +4  -0
tensorflow/compiler/xla/layout_util.h  +1  -0
tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc  +3  -3
tensorflow/compiler/xla/literal_util.h  +54  -67
tensorflow/compiler/xla/protobuf_util.cc  +0  -25
tensorflow/compiler/xla/protobuf_util.h  +4  -9
tensorflow/compiler/xla/service/BUILD  +23  -0
tensorflow/compiler/xla/service/cpu/cpu_compiler.cc  +13  -14
tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc  +2  -0
tensorflow/compiler/xla/service/gpu/gpu_compiler.cc  +5  -5
tensorflow/compiler/xla/service/hlo_computation.cc  +2  -2
tensorflow/compiler/xla/service/hlo_runner.cc  +199  -0
tensorflow/compiler/xla/service/hlo_runner.h  +100  -0
tensorflow/compiler/xla/service/transpose_folding.cc  +72  -33
tensorflow/compiler/xla/service/transpose_folding_test.cc  +17  -11
tensorflow/compiler/xla/service/user_computation.cc  +10  -2
tensorflow/compiler/xla/shape_util.cc  +31  -14
tensorflow/compiler/xla/tests/BUILD  +1  -11
tensorflow/compiler/xla/tests/hlo_test_base.cc  +5  -109
tensorflow/compiler/xla/tests/hlo_test_base.h  +7  -19
tensorflow/compiler/xla/tools/BUILD  +12  -0
tensorflow/compiler/xla/tools/hlo_proto_to_json.cc  +91  -0
tensorflow/compiler/xla/tools/parser/BUILD  +84  -0
tensorflow/compiler/xla/tools/parser/README.md  +69  -0
tensorflow/compiler/xla/tools/parser/hlo_lexer.cc  +270  -0
tensorflow/compiler/xla/tools/parser/hlo_lexer.h  +108  -0
tensorflow/compiler/xla/tools/parser/hlo_parser.cc  +502  -0
tensorflow/compiler/xla/tools/parser/hlo_parser.h  +37  -0
tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc  +240  -0
tensorflow/compiler/xla/tools/parser/hlo_token.h  +58  -0
tensorflow/compiler/xla/xla.proto  +2  -2
tensorflow/contrib/batching/BUILD  +22  -0
tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h  +463  -0
tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc  +438  -0
tensorflow/contrib/batching/batch_scheduler.h  +1  -1
tensorflow/contrib/cmake/external/cub.cmake  +1  -1
tensorflow/contrib/cmake/external/gif.cmake  +1  -1
tensorflow/contrib/cmake/external/jpeg.cmake  +1  -1
tensorflow/contrib/cmake/external/lmdb.cmake  +1  -1
tensorflow/contrib/cmake/external/snappy.cmake  +1  -1
tensorflow/contrib/eager/python/BUILD  +9  -2
tensorflow/contrib/eager/python/evaluator_test.py  +1  -1
tensorflow/contrib/eager/python/metrics_impl.py  +108  -48
tensorflow/contrib/eager/python/metrics_test.py  +51  -0
tensorflow/contrib/eager/python/saver_test.py  +39  -8
tensorflow/contrib/eager/python/tfe.py  +2  -0
tensorflow/contrib/factorization/g3doc/kmeans.md  +8  -4
tensorflow/contrib/factorization/kernels/clustering_ops.cc  +52  -0
tensorflow/contrib/factorization/kernels/clustering_ops_test.cc  +56  -0
tensorflow/contrib/factorization/ops/clustering_ops.cc  +19  -0
tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py  +57  -0
tensorflow/contrib/factorization/python/ops/clustering_ops.py  +115  -12
tensorflow/contrib/framework/__init__.py  +1  -0
tensorflow/contrib/framework/python/ops/arg_scope.py  +4  -3
tensorflow/contrib/learn/python/learn/learn_io/graph_io.py  +11  -3
tensorflow/contrib/makefile/Makefile  +1  -0
tensorflow/contrib/makefile/download_dependencies.sh  +4  -4
tensorflow/contrib/metrics/python/ops/metric_ops.py  +382  -209
tensorflow/contrib/metrics/python/ops/metric_ops_test.py  +26  -32
tensorflow/contrib/quantize/BUILD  +32  -1
tensorflow/contrib/quantize/python/copy_graph_test.py  +1  -1
tensorflow/contrib/quantize/python/fold_batch_norms.py  +265  -4
tensorflow/contrib/quantize/python/fold_batch_norms_test.py  +113  -259
tensorflow/contrib/quantize/python/graph_matcher.py  +200  -0
tensorflow/contrib/quantize/python/graph_matcher_test.py  +130  -0
tensorflow/contrib/quantize/python/quantize_parameterized_test.py  +114  -98
tensorflow/contrib/rnn/BUILD  +2  -0
tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py  +221  -143
tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc  +5  -2
tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc  +16  -9
tensorflow/contrib/seq2seq/ops/beam_search_ops.cc  +7  -4
tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py  +17  -14
tensorflow/contrib/tpu/python/tpu/tpu_estimator.py  +664  -649
tensorflow/contrib/training/python/training/bucket_ops.py  +2  -2
tensorflow/core/BUILD  +1  -1
tensorflow/core/framework/api_def.proto  +3  -2
tensorflow/core/framework/op_gen_lib.cc  +18  -0
tensorflow/core/framework/op_gen_lib_test.cc  +47  -3
tensorflow/core/graph/graph_constructor.cc  +10  -4
tensorflow/core/grappler/grappler_item_builder.cc  +6  -1
tensorflow/core/kernels/mkl_transpose_op.cc  +63  -57
tensorflow/core/kernels/transpose_functor.h  +13  -4
tensorflow/core/platform/default/build_config.bzl  +1  -1
tensorflow/core/platform/posix/port.cc  +4  -4
tensorflow/core/platform/s3/BUILD  +0  -0
tensorflow/core/platform/s3/s3_crypto.cc  +1  -1
tensorflow/core/platform/s3/s3_crypto.h  +0  -0
tensorflow/core/platform/s3/s3_file_system.cc  +2  -2
tensorflow/core/platform/s3/s3_file_system.h  +0  -0
tensorflow/core/platform/s3/s3_file_system_test.cc  +1  -1
tensorflow/core/platform/windows/port.cc  +4  -4
tensorflow/examples/learn/iris.py  +74  -27
tensorflow/examples/learn/random_forest_mnist.py  +36  -29
tensorflow/examples/learn/text_classification_character_rnn.py  +5  -14
tensorflow/python/BUILD  +2  -0
tensorflow/python/eager/backprop.py  +4  -7
tensorflow/python/eager/backprop_test.py  +11  -2
tensorflow/python/eager/context.py  +10  -6
tensorflow/python/eager/function.py  +47  -14
tensorflow/python/eager/function_test.py  +20  -2
tensorflow/python/eager/graph_callable.py  +10  -0
tensorflow/python/estimator/export/export.py  +50  -6
tensorflow/python/estimator/export/export_output.py  +6  -4
tensorflow/python/estimator/export/export_output_test.py  +6  -8
tensorflow/python/framework/test_util.py  +62  -2
tensorflow/python/framework/test_util_test.py  +71  -0
tensorflow/python/kernel_tests/BUILD  +2  -0
tensorflow/python/kernel_tests/qr_op_test.py  +62  -4
tensorflow/python/kernel_tests/resource_variable_ops_test.py  +32  -11
tensorflow/python/kernel_tests/rnn_test.py  +65  -26
tensorflow/python/ops/linalg_grad.py  +36  -6
tensorflow/python/ops/resource_variable_ops.py  +27  -0
tensorflow/python/ops/rnn.py  +48  -28
tensorflow/python/pywrap_tfe.i  +15  -1
tensorflow/python/saved_model/signature_def_utils_impl.py  +18  -5
tensorflow/python/training/input.py  +5  -21
tensorflow/python/training/saver.py  +8  -1
tensorflow/python/training/saver_test.py  +2  -1
tensorflow/tools/ci_build/ci_sanity.sh  +1  -0
tensorflow/workspace.bzl  +46  -46
WORKSPACE

@@ -5,7 +5,7 @@ http_archive(
     sha256 = "110fe68753413777944b473c25eed6368c4a0487cee23a7bac1b13cc49d3e257",
     strip_prefix = "rules_closure-4af89ef1db659eb41f110df189b67d4cf14073e1",
     urls = [
-        "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz",
+        "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz",
         "https://github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz",  # 2017-08-28
     ],
 )
tensorflow/BUILD

@@ -348,6 +348,7 @@ filegroup(
         "//tensorflow/compiler/xla/service/llvm_ir:all_files",
         "//tensorflow/compiler/xla/tests:all_files",
         "//tensorflow/compiler/xla/tools:all_files",
+        "//tensorflow/compiler/xla/tools/parser:all_files",
         "//tensorflow/contrib:all_files",
         "//tensorflow/contrib/all_reduce:all_files",
         "//tensorflow/contrib/android:all_files",
@@ -421,7 +422,6 @@ filegroup(
         "//tensorflow/contrib/remote_fused_graph/pylib:all_files",
         "//tensorflow/contrib/resampler:all_files",
         "//tensorflow/contrib/rnn:all_files",
-        "//tensorflow/contrib/s3:all_files",
         "//tensorflow/contrib/saved_model:all_files",
         "//tensorflow/contrib/saved_model/cc/saved_model:all_files",
         "//tensorflow/contrib/seq2seq:all_files",
@@ -475,6 +475,7 @@ filegroup(
         "//tensorflow/core/platform/cloud:all_files",
         "//tensorflow/core/platform/default/build_config:all_files",
         "//tensorflow/core/platform/hadoop:all_files",
+        "//tensorflow/core/platform/s3:all_files",
         "//tensorflow/core/profiler:all_files",
         "//tensorflow/core/profiler/internal:all_files",
         "//tensorflow/core/profiler/internal/advisor:all_files",
tensorflow/c/eager/BUILD

@@ -3,6 +3,7 @@ licenses(["notice"])  # Apache 2.0
 load(
     "//tensorflow:tensorflow.bzl",
+    "tf_cuda_cc_test",
     "tf_cc_test",
     "tf_copts",
     "tf_cuda_library",
@@ -50,7 +51,7 @@ tf_cuda_library(
     ],
 )

-tf_cc_test(
+tf_cuda_cc_test(
     name = "c_api_test",
     srcs = ["c_api_test.cc"],
     deps = [
tensorflow/c/eager/c_api.cc

@@ -54,9 +54,23 @@ string DeviceName(tensorflow::Device* d) {

 extern "C" {

-TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
+TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
+
+void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
+                                 size_t proto_len, TF_Status* status) {
+  TF_SetConfig(&options->session_options, proto, proto_len, status);
+}
+
+void TFE_ContextOptionsSetDevicePlacementPolicy(
+    TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
+  options->policy = policy;
+}
+
+void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
+
+TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
   TF_Graph* graph = TF_NewGraph();
-  TF_Session* session = TF_NewSession(graph, opts, status);
+  TF_Session* session = TF_NewSession(graph, &opts->session_options, status);
   if (status->status.ok()) {
     if (session->device_mgr == nullptr || session->devices.empty()) {
       status->status = tensorflow::errors::InvalidArgument(
@@ -71,9 +85,10 @@ TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
   }

   TFE_Context* ret = new TFE_Context(session);
+  ret->policy = opts->policy;
   ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime(
-      ret->session->device_mgr, opts->options.env, TF_GRAPH_DEF_VERSION,
-      &ret->func_lib_def, {}));
+      ret->session->device_mgr, opts->session_options.options.env,
+      TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {}));
   ret->rendezvous =
       new tensorflow::IntraProcessRendezvous(ret->session->device_mgr);
@@ -408,8 +423,10 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
 namespace {

 tensorflow::Status ValidateInputTypeAndPlacement(
-    tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op,
-    const tensorflow::OpKernel* kernel) {
+    TFE_Context* ctx, tensorflow::Device* host_device,
+    tensorflow::Device* op_device, TFE_Op* op,
+    const tensorflow::OpKernel* kernel,
+    std::vector<TFE_TensorHandle*>* copied_tensors) {
   const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
   if (memtypes.size() != op->inputs.size()) {
     return tensorflow::errors::InvalidArgument(
@@ -421,11 +438,42 @@ tensorflow::Status ValidateInputTypeAndPlacement(
     const tensorflow::Device* actual_device =
         op->input_devices[i] == nullptr ? host_device : op->input_devices[i];
     if (expected_device != actual_device) {
-      return tensorflow::errors::InvalidArgument(
-          "cannot compute ", op->name, " as input #", i,
-          " was expected to be on ", expected_device->name(),
-          " but is actually on ", actual_device->name(),
-          " (operation running on ", op_device->name(), ")");
+      switch (ctx->policy) {
+        case TFE_DEVICE_PLACEMENT_EXPLICIT:
+          return tensorflow::errors::InvalidArgument(
+              "cannot compute ", op->name, " as input #", i,
+              " was expected to be on ", expected_device->name(),
+              " but is actually on ", actual_device->name(),
+              " (operation running on ", op_device->name(), ")");
+        case TFE_DEVICE_PLACEMENT_WARN:
+          LOG(WARNING) << "before computing " << op->name << " input #" << i
+                       << " was expected to be on " << expected_device->name()
+                       << " but is actually on " << actual_device->name()
+                       << " (operation running on " << op_device->name()
+                       << "). This triggers a copy which can be a performance "
+                          "bottleneck.";
+          break;
+        case TFE_DEVICE_PLACEMENT_SILENT:
+          // Do nothing.
+          break;
+      }
+      // We are only here if the policy is warn or silent copies, so we should
+      // trigger a copy.
+      TFE_TensorHandle original{op->inputs[i], op->input_devices[i]};
+      TF_Status* s = TF_NewStatus();
+      TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice(
+          &original, ctx, expected_device->name().c_str(), s);
+      if (!s->status.ok()) {
+        tensorflow::Status status = s->status;
+        delete s;
+        return tensorflow::errors::Internal(
+            "Failed copying input tensor from ", actual_device->name(), " to ",
+            expected_device->name(), " in order to run ", op->name, ": ",
+            status.error_message());
+      }
+      op->inputs[i] = copied_tensor->t;
+      copied_tensors->push_back(copied_tensor);
+      op->input_devices[i] = copied_tensor->d;
+      delete s;
     }
     if (op->inputs[i].dtype() != kernel->input_type(i)) {
       return tensorflow::errors::InvalidArgument(
@@ -468,10 +516,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     }
     tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
   }
-  status->status = ValidateInputTypeAndPlacement(ctx->devices()[0], device, op,
-                                                 kernel->kernel());
+  std::vector<TFE_TensorHandle*> copied_tensors;
+  status->status = ValidateInputTypeAndPlacement(
+      ctx, ctx->devices()[0], device, op, kernel->kernel(), &copied_tensors);
   output_memory_types = &kernel->kernel()->output_memory_types();
-  if (!status->status.ok()) return;
+  if (!status->status.ok()) {
+    for (auto* t : copied_tensors) {
+      TFE_DeleteTensorHandle(t);
+    }
+    return;
+  }
   // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
@@ -483,6 +535,9 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
   // sense for FunctionLibraryRuntime to ensure thread-safe access to
   // FunctionLibraryDefinition?).
   status->status = kernel->Run(&op->inputs, &outputs);
+  for (auto* t : copied_tensors) {
+    TFE_DeleteTensorHandle(t);
+  }
   if (!status->status.ok()) return;
   *num_retvals = std::min<int>(*num_retvals, outputs.size());
   for (int i = 0; i < *num_retvals; ++i) {
tensorflow/c/eager/c_api.h

@@ -43,14 +43,46 @@ limitations under the License.
 extern "C" {
 #endif

+typedef struct TFE_ContextOptions TFE_ContextOptions;
+
+// Return a new options object.
+TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions();
+
+// Set the config in TF_ContextOptions.options.
+// config should be a serialized tensorflow.ConfigProto proto.
+// If config was not parsed successfully as a ConfigProto, record the
+// error information in *status.
+TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig(
+    TFE_ContextOptions* options, const void* proto, size_t proto_len,
+    TF_Status* status);
+
+// Controls how to act when we try to run an operation on a given device but
+// some input tensors are not on that device.
+typedef enum TFE_ContextDevicePlacementPolicy {
+  // The default: running operations with input tensors on the wrong device will
+  // fail.
+  TFE_DEVICE_PLACEMENT_EXPLICIT = 0,
+  // Copy the tensor to the right device but log a warning.
+  TFE_DEVICE_PLACEMENT_WARN = 1,
+  // Silently copy the tensor, which has a performance cost since the
+  // operation will be blocked till the copy completes.
+  TFE_DEVICE_PLACEMENT_SILENT = 2,
+} TFE_ContextDevicePlacementPolicy;
+
+TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
+    TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
+
+// Destroy an options object.
+TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);
+
 // "Context" under which operations/functions are executed. It encapsulates
 // things like the available devices, resource manager etc.
 //
 // TODO(ashankar): Merge with TF_Session?
 typedef struct TFE_Context TFE_Context;
-TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(
-    const TF_SessionOptions* opts, TF_Status* status);
+
+TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(
+    const TFE_ContextOptions* opts, TF_Status* status);
 TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status);
 TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
                                                             TF_Status* status);
tensorflow/c/eager/c_api_internal.h

@@ -35,9 +35,16 @@ limitations under the License.
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"

+struct TFE_ContextOptions {
+  TF_SessionOptions session_options;
+  TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_EXPLICIT};
+};
+
 struct TFE_Context {
   explicit TFE_Context(TF_Session* s) : session(s) {}

+  TFE_ContextDevicePlacementPolicy policy;
+
   // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph.
   TF_Session* session;
   tensorflow::Rendezvous* rendezvous;
tensorflow/c/eager/c_api_test.cc

@@ -62,10 +62,10 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
 void BM_InitOp(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
   TFE_TensorHandle* m = TestMatrixTensorHandle();
   tensorflow::testing::StartTiming();
@@ -84,10 +84,10 @@ BENCHMARK(BM_InitOp);
 void BM_Execute(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
   TFE_TensorHandle* m = TestMatrixTensorHandle();
   TFE_Op* matmul = MatMulOp(ctx, m, m);
@@ -109,9 +109,9 @@ BENCHMARK(BM_Execute);
 TEST(CAPI, Context) {
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
   TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -150,9 +150,9 @@ TEST(CAPI, TensorHandle) {
 TEST(CAPI, TensorHandleCopyBetweenDevices) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status.get());
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
   TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
@@ -216,12 +216,58 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) {
   EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 }

+TEST(CAPI, TensorHandleSilentCopy) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_Context* ctx = TFE_NewContext(opts, status.get());
+  TFE_DeleteContextOptions(opts);
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
+  TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  const int num_devices = TF_DeviceListCount(devices);
+
+  // Disable the test if no GPU is present.
+  if (num_devices > 1) {
+    const int device_to_use = 1;
+    const string name(TF_DeviceListName(devices, device_to_use, status.get()));
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    TFE_TensorHandle* hgpu =
+        TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
+    TFE_OpSetDevice(matmul, name.c_str(), status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    TFE_TensorHandle* retvals[1];
+    int num_retvals = 1;
+    TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    TFE_DeleteOp(matmul);
+    TFE_DeleteTensorHandle(retvals[0]);
+    TFE_DeleteTensorHandle(hgpu);
+  }
+
+  TF_DeleteDeviceList(devices);
+  TF_DeleteTensor(t);
+  TFE_DeleteTensorHandle(hcpu);
+  TFE_DeleteContext(ctx, status.get());
+  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+}
+
 TEST(CAPI, Execute) {
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
   TFE_TensorHandle* m = TestMatrixTensorHandle();
   TFE_Op* matmul = MatMulOp(ctx, m, m);
@@ -285,10 +331,10 @@ string MatMulFunction() {
 TEST(CAPI, FunctionDefAndExecute) {
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
   string function_def = MatMulFunction();
   TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
@@ -326,10 +372,10 @@ TEST(CAPI, FunctionDefAndExecute) {
 void BM_ExecuteFunction(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
   string function_def = MatMulFunction();
   TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
@@ -406,10 +452,10 @@ TEST(CAPI, Variables) {
   // Variables use resource handles, so this is really a test for resource
   // tensor handling.
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
   TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -446,10 +492,10 @@ TEST(CAPI, Variables) {
 void BM_ReadVariable(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
   TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow/compiler/xla/client/computation_builder.h
浏览文件 @
b3d5ec90
...
@@ -138,6 +138,11 @@ class ComputationBuilder {
   ComputationDataHandle ConstantR2(
       std::initializer_list<std::initializer_list<NativeT>> values);
+  template <typename NativeT>
+  ComputationDataHandle ConstantFromArrayWithLayout(
+      const Array<NativeT>& values, const Layout& layout);
+  template <typename NativeT>
+  ComputationDataHandle ConstantFromArray(const Array<NativeT>& values);
   template <typename NativeT>
   ComputationDataHandle ConstantR2FromArray2DWithLayout(
       const Array2D<NativeT>& values, const Layout& layout);
   template <typename NativeT>
...
@@ -910,48 +915,54 @@ ComputationDataHandle ComputationBuilder::ConstantR2(
 }
 template <typename NativeT>
-ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
-    const Array2D<NativeT>& values, const Layout& layout) {
+ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout(
+    const Array<NativeT>& values, const Layout& layout) {
   return ConstantOp([&values, &layout](Literal* literal) {
-    literal->PopulateR2FromArray2DWithLayout(values, layout);
+    literal->PopulateFromArrayWithLayout(values, layout);
   });
 }
+template <typename NativeT>
+ComputationDataHandle ComputationBuilder::ConstantFromArray(
+    const Array<NativeT>& values) {
+  return ConstantOp(
+      [&values](Literal* literal) { literal->PopulateFromArray(values); });
+}
+template <typename NativeT>
+ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
+    const Array2D<NativeT>& values, const Layout& layout) {
+  return ConstantFromArrayWithLayout(values, layout);
+}
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D(
     const Array2D<NativeT>& values) {
-  return ConstantOp(
-      [&values](Literal* literal) { literal->PopulateR2FromArray2D(values); });
+  return ConstantFromArray(values);
 }
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout(
     const Array3D<NativeT>& values, const Layout& layout) {
-  return ConstantOp([&values, &layout](Literal* literal) {
-    literal->PopulateR3FromArray3DWithLayout(values, layout);
-  });
+  return ConstantFromArrayWithLayout(values, layout);
 }
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D(
     const Array3D<NativeT>& values) {
-  return ConstantOp(
-      [&values](Literal* literal) { literal->PopulateR3FromArray3D(values); });
+  return ConstantFromArray(values);
 }
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout(
     const Array4D<NativeT>& values, const Layout& layout) {
-  return ConstantOp([&values, &layout](Literal* literal) {
-    literal->PopulateR4FromArray4DWithLayout(values, layout);
-  });
+  return ConstantFromArrayWithLayout(values, layout);
 }
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
     const Array4D<NativeT>& values) {
-  return ConstantOp(
-      [&values](Literal* literal) { literal->PopulateR4FromArray4D(values); });
+  return ConstantFromArray(values);
 }
 }  // namespace xla
...
tensorflow/compiler/xla/layout_util.cc
...
@@ -83,6 +83,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
   return CreateDefaultLayoutForRank(shape.dimensions_size());
 }
+/* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) {
+  return CreateDefaultLayoutForRank(rank);
+}
 /* static */ Layout LayoutUtil::GetDefaultLayoutForR2() {
   return CreateDefaultLayoutForRank(2);
 }
...
tensorflow/compiler/xla/layout_util.h
...
@@ -40,6 +40,7 @@ class LayoutUtil {
   static Layout GetDefaultLayoutForShape(const Shape& shape);
   // Helper functions that create default layouts for various ranks.
+  static Layout GetDefaultLayoutForRank(int64 rank);
   static Layout GetDefaultLayoutForR2();
   static Layout GetDefaultLayoutForR3();
   static Layout GetDefaultLayoutForR4();
...
tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
...
@@ -206,9 +206,9 @@ void AllocateFlags() {
           flag_values->xla_gpu_disable_multi_streaming(),
           "If true, multi-streaming in the GPU backend is disabled."),
       tensorflow::Flag(
-          "xla_dump_debug_json_to",
-          flag_values->mutable_xla_dump_debug_json_to(),
-          "Dump compilation artifacts as JSON into this directory."),
+          "xla_dump_hlo_proto_to",
+          flag_values->mutable_xla_dump_hlo_proto_to(),
+          "Dump compilation artifacts as proto binary into this directory."),
       tensorflow::Flag(
           "xla_test_all_output_layouts",
          bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),
...
tensorflow/compiler/xla/literal_util.h
...
@@ -334,6 +334,11 @@ class Literal {
   // WithLayout use the default XLA layout for the literal's linear
   // representation in memory.
   template <typename NativeT>
+  static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
+  template <typename NativeT>
+  static std::unique_ptr<Literal> CreateFromArrayWithLayout(
+      const Array<NativeT>& values, const Layout& layout);
+  template <typename NativeT>
   static std::unique_ptr<Literal> CreateR2FromArray2D(
       const Array2D<NativeT>& values);
   template <typename NativeT>
...
@@ -481,6 +486,11 @@ class Literal {
       std::initializer_list<std::initializer_list<NativeT>> values,
       const Layout& layout);
   template <typename NativeT>
+  void PopulateFromArray(const Array<NativeT>& values);
+  template <typename NativeT>
+  void PopulateFromArrayWithLayout(const Array<NativeT>& values,
+                                   const Layout& layout);
+  template <typename NativeT>
   void PopulateR2FromArray2D(const Array2D<NativeT>& values);
   template <typename NativeT>
   void PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
...
@@ -816,33 +826,42 @@ template <typename NativeT>
 }
 template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout(
-    const Array2D<NativeT>& values, const Layout& layout) {
+/* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout(
+    const Array<NativeT>& values, const Layout& layout) {
   auto literal = MakeUnique<Literal>();
-  literal->PopulateR2FromArray2DWithLayout(values, layout);
+  literal->PopulateFromArrayWithLayout(values, layout);
   return literal;
 }
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> Literal::CreateFromArray(
+    const Array<NativeT>& values) {
+  return CreateFromArrayWithLayout(
+      values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
+}
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout(
+    const Array2D<NativeT>& values, const Layout& layout) {
+  return CreateFromArrayWithLayout(values, layout);
+}
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2D(
     const Array2D<NativeT>& values) {
-  return CreateR2FromArray2DWithLayout(values,
-                                       LayoutUtil::GetDefaultLayoutForR2());
+  return CreateFromArray(values);
 }
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3DWithLayout(
     const Array3D<NativeT>& values, const Layout& layout) {
-  auto literal = MakeUnique<Literal>();
-  literal->PopulateR3FromArray3DWithLayout(values, layout);
-  return literal;
+  return CreateFromArrayWithLayout(values, layout);
 }
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3D(
     const Array3D<NativeT>& values) {
-  return CreateR3FromArray3DWithLayout(values,
-                                       LayoutUtil::GetDefaultLayoutForR3());
+  return CreateFromArray(values);
 }
 template <typename NativeT>
...
@@ -901,16 +920,13 @@ template <typename NativeT>
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4D(
     const Array4D<NativeT>& values) {
-  return CreateR4FromArray4DWithLayout(values,
-                                       LayoutUtil::GetDefaultLayoutForR4());
+  return CreateFromArray(values);
 }
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4DWithLayout(
     const Array4D<NativeT>& values, const Layout& layout) {
-  auto literal = MakeUnique<Literal>();
-  literal->PopulateR4FromArray4DWithLayout(values, layout);
-  return literal;
+  return CreateFromArrayWithLayout(values, layout);
 }
 template <typename NativeT>
...
@@ -1070,82 +1086,53 @@ void Literal::PopulateR2(
 }
 template <typename NativeT>
-void Literal::PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
-                                              const Layout& layout) {
+void Literal::PopulateFromArrayWithLayout(const Array<NativeT>& values,
+                                          const Layout& layout) {
   *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
-      primitive_util::NativeToPrimitiveType<NativeT>(),
-      {values.height(), values.width()},
-      AsInt64Slice(layout.minor_to_major()));
+      primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
+      AsInt64Slice(layout.minor_to_major()));
+  Reserve(values.num_elements());
+  values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
+                     NativeT value) { this->Set(indices, value); });
+}
-  const int64 dim1_size = values.width();
-  const int64 dim0_size = values.height();
-  CHECK_EQ(dim0_size, shape().dimensions(0));
-  CHECK_EQ(dim1_size, shape().dimensions(1));
-  Reserve(dim1_size * dim0_size);
-  for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
-    for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) {
-      Set({dim0, dim1}, values(dim0, dim1));
-    }
-  }
-}
+template <typename NativeT>
+void Literal::PopulateFromArray(const Array<NativeT>& values) {
+  PopulateFromArrayWithLayout(
+      values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
+}
+template <typename NativeT>
+void Literal::PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
+                                              const Layout& layout) {
+  PopulateFromArrayWithLayout(values, layout);
+}
 template <typename NativeT>
 void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
-  PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
+  PopulateFromArray(values);
 }
 template <typename NativeT>
 void Literal::PopulateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
                                               const Layout& layout) {
-  *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
-      primitive_util::NativeToPrimitiveType<NativeT>(),
-      {values.n1(), values.n2(), values.n3()},
-      AsInt64Slice(layout.minor_to_major()));
-  CHECK_EQ(values.n1(), shape().dimensions(0));
-  CHECK_EQ(values.n2(), shape().dimensions(1));
-  CHECK_EQ(values.n3(), shape().dimensions(2));
-  Reserve(values.n1() * values.n2() * values.n3());
-  for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) {
-    for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) {
-      for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) {
-        Set({dim0, dim1, dim2}, values(dim0, dim1, dim2));
-      }
-    }
-  }
+  PopulateFromArrayWithLayout(values, layout);
 }
 template <typename NativeT>
 void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
-  PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
+  PopulateFromArray(values);
 }
 template <typename NativeT>
 void Literal::PopulateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
                                               const Layout& layout) {
-  *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
-      primitive_util::NativeToPrimitiveType<NativeT>(),
-      {values.planes(), values.depth(), values.height(), values.width()},
-      AsInt64Slice(layout.minor_to_major()));
-  CHECK_EQ(values.n1(), shape().dimensions(0));
-  CHECK_EQ(values.n2(), shape().dimensions(1));
-  CHECK_EQ(values.n3(), shape().dimensions(2));
-  CHECK_EQ(values.n4(), shape().dimensions(3));
-  Reserve(values.n1() * values.n2() * values.n3() * values.n4());
-  for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) {
-    for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) {
-      for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) {
-        for (int64 dim3 = 0; dim3 < values.n4(); ++dim3) {
-          Set({dim0, dim1, dim2, dim3}, values(dim0, dim1, dim2, dim3));
-        }
-      }
-    }
-  }
+  PopulateFromArrayWithLayout(values, layout);
 }
 template <typename NativeT>
 void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
-  PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
+  PopulateFromArray(values);
 }
 template <typename NativeT, typename FnType>
...
tensorflow/compiler/xla/protobuf_util.cc
...
@@ -37,20 +37,6 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
   return (serialized1 == serialized2);
 }
-StatusOr<string> ToJson(const tensorflow::protobuf::Message& message) {
-  string json_output;
-  tensorflow::protobuf::util::JsonPrintOptions json_options;
-  json_options.add_whitespace = true;
-  json_options.always_print_primitive_fields = true;
-  auto status = tensorflow::protobuf::util::MessageToJsonString(
-      message, &json_output, json_options);
-  if (!status.ok()) {
-    return InternalError("MessageToJsonString failed: %s",
-                         status.error_message().data());
-  }
-  return json_output;
-}
 namespace {
 string SanitizeFilename(const string& file_name) {
...
@@ -65,17 +51,6 @@ string SanitizeFilename(const string& file_name) {
 }  // namespace
-Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message,
-                           const string& directory, const string& file_name) {
-  TF_ASSIGN_OR_RETURN(const string json_output, ToJson(message));
-  tensorflow::Env* env = tensorflow::Env::Default();
-  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory));
-  string safe_file_name = SanitizeFileName(file_name) + ".json";
-  const string path = tensorflow::io::JoinPath(directory, safe_file_name);
-  return tensorflow::WriteStringToFile(env, path, json_output);
-}
 Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
                             const string& directory, const string& file_name) {
   tensorflow::Env* env = tensorflow::Env::Default();
...
tensorflow/compiler/xla/protobuf_util.h
...
@@ -32,17 +32,12 @@ namespace protobuf_util {
 extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
                            const tensorflow::protobuf::Message& m2);
-// Returns 'message' as a JSON string.
-StatusOr<string> ToJson(const tensorflow::protobuf::Message& message);
-// Writes the given message in binary proto or JSON format to the path formed by
-// joining 'directory/file_name.pb' (or file_name.json). The 'directory' is
-// recursively created if it doesn't already exist, and the 'file_name' is
-// sanitized by replacing illegal characters with underscore '_'.
+// Writes the given message in binary proto to the path formed by joining
+// 'directory/file_name.pb'. The 'directory' is recursively created if it
+// doesn't already exist, and the 'file_name' is sanitized by replacing
+// illegal characters with underscore '_'.
 Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
                             const string& directory, const string& file_name);
-Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message,
-                           const string& directory, const string& file_name);
 }  // namespace protobuf_util
 }  // namespace xla
...
tensorflow/compiler/xla/service/BUILD
...
@@ -2064,6 +2064,29 @@ tf_cc_test(
     ],
 )
+cc_library(
+    name = "hlo_runner",
+    srcs = ["hlo_runner.cc"],
+    hdrs = ["hlo_runner.h"],
+    deps = [
+        ":executable",
+        ":hlo",
+        ":transfer_manager",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:backend",
+        "//tensorflow/compiler/xla/service:compiler",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_no_cuda",
+        "//third_party/eigen3",
+    ],
+)
 # -----------------------------------------------------------------------------
 filegroup(
...
tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
...
@@ -475,8 +475,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
   // ownership is std::moved.
   const bool embed_ir_in_executable =
       module->config().debug_options().xla_embed_ir_in_executable();
-  const string dump_debug_json_to =
-      module->config().debug_options().xla_dump_debug_json_to();
+  const string xla_dump_hlo_proto_to =
+      module->config().debug_options().xla_dump_hlo_proto_to();
   if (options::CpuParallelBackendRequested(module->config())) {
     VLOG(1) << "Using parallel cpu backend";
...
@@ -496,10 +496,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
   // print one ourselves.
   XLA_VLOG_LINES(2, assignment->ToString());
-  if (!dump_debug_json_to.empty()) {
+  if (!xla_dump_hlo_proto_to.empty()) {
     HloProto proto = MakeHloProto(*module, *assignment);
-    TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
-        proto, dump_debug_json_to, module->name()));
+    TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
+        proto, xla_dump_hlo_proto_to, module->name()));
   }
   // If we are using the parallel CPU backend, we need to create map from
...
@@ -603,12 +603,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
   // print one ourselves.
   XLA_VLOG_LINES(2, assignment->ToString());
-  if (!dump_debug_json_to.empty()) {
+  if (!xla_dump_hlo_proto_to.empty()) {
     HloProto proto = MakeHloProto(*module, *assignment);
-    TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
-        proto, dump_debug_json_to, module->name()));
+    TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
+        proto, xla_dump_hlo_proto_to, module->name()));
   }
   // Each computation is a single function. Emit all embedded computations
   // before the entry computation. The order of computations returned from
   // GetEmbeddedComputations guarantees that a called computation occurs
...
@@ -775,12 +774,12 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
   // print one ourselves.
   XLA_VLOG_LINES(2, assignment->ToString());
-  const string dump_debug_json_to =
-      module->config().debug_options().xla_dump_debug_json_to();
-  if (!dump_debug_json_to.empty()) {
+  const string xla_dump_hlo_proto_to =
+      module->config().debug_options().xla_dump_hlo_proto_to();
+  if (!xla_dump_hlo_proto_to.empty()) {
     HloProto proto = MakeHloProto(*module, *assignment);
-    TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
-        proto, dump_debug_json_to, module->name()));
+    TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
+        proto, xla_dump_hlo_proto_to, module->name()));
   }
   IrEmitter ir_emitter(*module, *assignment, &llvm_module,
...
tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
...
@@ -136,6 +136,8 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
       instruction->opcode() == HloOpcode::kCall ||
       instruction->opcode() == HloOpcode::kCustomCall ||
       instruction->opcode() == HloOpcode::kSelectAndScatter ||
+      instruction->opcode() == HloOpcode::kGetTupleElement ||
+      instruction->opcode() == HloOpcode::kBitcast ||
       (instruction->opcode() == HloOpcode::kConvolution &&
        PotentiallyImplementedAsEigenConvolution(*instruction)) ||
       PotentiallyImplementedAsEigenDot(*instruction) ||
...
tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
...
@@ -318,12 +318,12 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
   // print one ourselves.
   XLA_VLOG_LINES(2, buffer_assignment->ToString());
-  const string dump_debug_json_to =
-      module->config().debug_options().xla_dump_debug_json_to();
-  if (!dump_debug_json_to.empty()) {
+  const string xla_dump_hlo_proto_to =
+      module->config().debug_options().xla_dump_hlo_proto_to();
+  if (!xla_dump_hlo_proto_to.empty()) {
     HloProto proto = MakeHloProto(*module, *buffer_assignment);
-    TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
-        proto, dump_debug_json_to, module->name()));
+    TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
+        proto, xla_dump_hlo_proto_to, module->name()));
   }
   IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(),
...
tensorflow/compiler/xla/service/hlo_computation.cc
...
@@ -373,8 +373,8 @@ string HloComputation::ToString(int nested_level) const {
   for (int i = 0; i < nested_level; i++) {
     s << "    ";
   }
-  s << name() << " " << ShapeUtil::HumanString(ComputeProgramShape()) << " {\n";
+  s << "%" << name() << " " << ShapeUtil::HumanString(ComputeProgramShape())
+    << " {\n";
   for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
     for (int i = 0; i < nested_level; i++) {
       s << "    ";
...
tensorflow/compiler/xla/service/hlo_runner.cc
0 → 100644
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include <set>
#include <string>
#include <utility>
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace se = ::perftools::gputools;

namespace xla {

/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromHloProtoFile(const char* filename,
                                      const DebugOptions& debug_options) {
  HloProto proto;
  TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
                                                 filename, &proto));
  HloModuleConfig config;
  config.set_debug_options(debug_options);
  TF_ASSIGN_OR_RETURN(auto module,
                      HloModule::CreateFromProto(proto.hlo_module(),
                                                 VersionedComputationHandle(),
                                                 config));
  return std::move(module);
}

// Define this in .cc file to avoid having to include eigen or forward declare
// these types in the header.
struct HloRunner::EigenThreadPoolWrapper {
  std::unique_ptr<EigenThreadPoolWrapper> pool;
  std::unique_ptr<Eigen::ThreadPoolDevice> device;
};

HloRunner::HloRunner() {}

HloRunner::HloRunner(se::Platform* platform) {
  BackendOptions backend_options;
  backend_options.set_platform(platform);
  backend_ = Backend::CreateBackend(backend_options).ConsumeValueOrDie();
  VLOG(1) << "Created HloRunner for platform: " << platform->Name();
}

HloRunner::~HloRunner() {
  // Deallocate all the memory allocated during the tests.
  for (auto& allocation : allocations_) {
    backend().default_stream_executor()->Deallocate(&allocation);
  }
}

StatusOr<se::DeviceMemoryBase> HloRunner::Execute(
    std::unique_ptr<HloModule> module,
    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
    Shape* result_shape) {
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      backend().compiler()->Compile(std::move(module),
                                    backend().default_stream_executor()));

  se::Stream stream(backend().default_stream_executor());
  stream.Init();

  ExecutableRunOptions run_options;
  run_options.set_stream(&stream);
  run_options.set_allocator(backend().memory_allocator());
  run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool());
  run_options.set_intra_op_thread_pool(
      backend().eigen_intra_op_thread_pool_device());

  HloExecutionProfile hlo_execution_profile;
  ServiceExecutableRunOptions service_run_options(
      run_options, backend().StreamBorrower(),
      backend().inter_op_thread_pool());
  TF_ASSIGN_OR_RETURN(
      se::DeviceMemoryBase result,
      executable->ExecuteOnStream(&service_run_options, arguments,
                                  &hlo_execution_profile));
  TF_RET_CHECK(stream.BlockHostUntilDone());

  allocations_.push_back(result);

  *result_shape = executable->result_shape();

  if (ShapeUtil::IsTuple(*result_shape)) {
    // We must record element buffers of tuples as well to avoid leaks.
    DCHECK(!ShapeUtil::IsNestedTuple(*result_shape));
    TF_ASSIGN_OR_RETURN(
        std::vector<se::DeviceMemoryBase> element_buffers,
        backend().transfer_manager()->ShallowCopyTupleFromDevice(
            backend().default_stream_executor(), result, *result_shape));

    // A tuple may contain the same buffer in more than one element. Keep track
    // of the buffers already added to avoid duplicates in allocations_.
    std::set<void*> added_opaques;
    for (auto element_buffer : element_buffers) {
      if (added_opaques.count(element_buffer.opaque()) == 0) {
        CHECK(element_buffer.opaque() != nullptr);
        added_opaques.insert(element_buffer.opaque());
        allocations_.push_back(element_buffer);
      }
    }
  }

  return result;
}

se::DeviceMemoryBase HloRunner::TransferToDevice(const Literal& literal) {
  // Allocate memory on the device using the stream executor.
  int64 allocation_size =
      backend().transfer_manager()->GetByteSizeRequirement(literal.shape());
  se::DeviceMemoryBase allocation =
      backend().default_stream_executor()->AllocateArray<uint8>(
          allocation_size);
  allocations_.push_back(allocation);

  TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToDevice(
      backend().default_stream_executor(), literal, &allocation));

  return allocation;
}

std::unique_ptr<Literal> HloRunner::TransferFromDevice(
    const Shape& shape, se::DeviceMemoryBase device_base) {
  auto literal = MakeUnique<Literal>();
  TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromDevice(
      backend().default_stream_executor(), device_base, shape, shape,
      literal.get()));
  return literal;
}

std::unique_ptr<Literal> HloRunner::ExecuteAndTransfer(
    std::unique_ptr<HloModule> module,
    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
  Shape result_shape;
  se::DeviceMemoryBase device_base =
      Execute(std::move(module), arguments, &result_shape).ValueOrDie();
  return TransferFromDevice(result_shape, device_base);
}

template <>
std::unique_ptr<Literal> HloRunner::Execute(
    std::unique_ptr<HloModule> module,
    const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>>& literals) {
  std::vector<se::DeviceMemoryBase> arguments;
  for (const auto& literal : literals) {
    arguments.push_back(TransferToDevice(*literal));
  }
  return ExecuteAndTransfer(std::move(module), arguments);
}

template <>
std::unique_ptr<Literal> HloRunner::Execute(
    std::unique_ptr<HloModule> module,
    const tensorflow::gtl::ArraySlice<Literal*>& literals) {
  std::vector<se::DeviceMemoryBase> arguments;
  for (const auto& literal : literals) {
    arguments.push_back(TransferToDevice(*literal));
  }
  return ExecuteAndTransfer(std::move(module), arguments);
}

Backend& HloRunner::backend() {
  if (!backend_) {
    backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
    VLOG(1) << "executing on platform " << backend().platform()->Name();
  }
  return *backend_;
}

}  // namespace xla
tensorflow/compiler/xla/service/hlo_runner.h
0 → 100644
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {

// A base class for running an HloModule. This executes the given HloModule on a
// certain backend directly without using the client interface. HloModule can be
// explicitly built, or loaded from a serialization file (e.g., hlo proto file).
class HloRunner {
 public:
  HloRunner();

  HloRunner(::perftools::gputools::Platform* platform);

  ~HloRunner();

  // Reads the binary proto file in xla.HloProto format, creates and returns the
  // HloModule.
  static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloProtoFile(
      const char* filename, const DebugOptions& debug_options);

  // Executes the given module with given literals as input and returns the
  // result as a Literal. The LiteralPtr type accepts Literal* or
  // std::unique_ptr<Literal>.
  template <typename LiteralPtr>
  std::unique_ptr<Literal> Execute(
      std::unique_ptr<HloModule> module,
      const tensorflow::gtl::ArraySlice<LiteralPtr>& literals);

  // Executes the given module and returns a global data handle.
  StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
      std::unique_ptr<HloModule> module,
      tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
          arguments,
      Shape* result_shape);

  // Transfers the given literal to the device and returns the data handle.
  perftools::gputools::DeviceMemoryBase TransferToDevice(
      const Literal& literal);

  // Transfers the array referred to by the given handle from the device and
  // returns as a Literal.
  std::unique_ptr<Literal> TransferFromDevice(
      const Shape& shape, perftools::gputools::DeviceMemoryBase device_base);

  // Executes the given module and return the result as a Literal.
  std::unique_ptr<Literal> ExecuteAndTransfer(
      std::unique_ptr<HloModule> module,
      tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
          arguments);

  // If backend is not created in the constructor, creates and returns the
  // default backend. If creation fails, crashes the program.
  //
  // This creates the backend lazily so it's possible to instantiate an
  // HloRunner in a program without any backends linked in.
  Backend& backend();

 private:
  struct EigenThreadPoolWrapper;

  std::vector<perftools::gputools::DeviceMemoryBase> allocations_;

  std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;

  std::unique_ptr<Backend> backend_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
tensorflow/compiler/xla/service/transpose_folding.cc
@@ -58,14 +58,32 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution(
     return {};
   }
 
-  // We only support folding the RHS.
-  const int64 kRhsOperandIndex = 1;
-  auto& operand = *convolution.operand(kRhsOperandIndex);
-  if (operand.opcode() == HloOpcode::kTranspose && operand.user_count() == 1) {
-    return transposable_conv_operands(convolution, {kRhsOperandIndex});
-  }
-  return {};
+  const ConvolutionDimensionNumbers& dnums =
+      convolution.convolution_dimension_numbers();
+  TransposeFolding::OperandIndices operand_set;
+  for (int64 i = 0; i < convolution.operand_count(); ++i) {
+    auto& operand = *convolution.operand(i);
+    if (operand.opcode() == HloOpcode::kTranspose &&
+        operand.user_count() == 1) {
+      const auto& transpose_dimensions = operand.dimensions();
+      // We can transpose the LHS so long as it doesn't move around spatial
+      // dimensions because ConvolutionDimensionNumbers doesn't have different
+      // fields for input and output spatial dimensions.
+      if (i == 0 &&
+          std::any_of(dnums.spatial_dimensions().begin(),
+                      dnums.spatial_dimensions().end(),
+                      [&](const int64 spatial_dimension) {
+                        return transpose_dimensions[spatial_dimension] !=
+                               spatial_dimension;
+                      })) {
+        continue;
+      }
+      operand_set.push_back(i);
+    }
+  }
+  return transposable_conv_operands(convolution, operand_set);
 }
 using InstructionOperandsPair =
@@ -98,40 +116,61 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair) {
 // Returns whether the module is changed.
 bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
   auto& convolution = *pair.first;
-
-  // We only support fusing the RHS transpose into convolution.
-  //
-  // ConvolutionDimensionNumbers doesn't make enough of a distinction between
-  // the output and the activations.
-  //
-  // TODO(b/37125184): Support transposing the LHS too.
-  if (pair.second.size() != 1 || pair.second.front() != 1) {
-    return false;
-  }
+  auto& operand_indices = pair.second;
 
   const ConvolutionDimensionNumbers& dnums =
       convolution.convolution_dimension_numbers();
-
-  HloInstruction& transpose = *convolution.mutable_operand(1);
-  CHECK_EQ(transpose.opcode(), HloOpcode::kTranspose);
-  const auto& transpose_dimensions = transpose.dimensions();
-  HloInstruction& transpose_operand = *transpose.mutable_operand(0);
-
-  // Everything remains the same except for the kernel dimension numbers. We
-  // need to apply the transpose permutation to the original shape to figure out
-  // what the new logical dimensions are.
   ConvolutionDimensionNumbers new_dnums = dnums;
-  new_dnums.set_kernel_input_feature_dimension(
-      transpose_dimensions[dnums.kernel_input_feature_dimension()]);
-  new_dnums.set_kernel_output_feature_dimension(
-      transpose_dimensions[dnums.kernel_output_feature_dimension()]);
-  for (auto& kernel_spatial_dimension :
-       *new_dnums.mutable_kernel_spatial_dimensions()) {
-    kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
-  }
+
+  HloInstruction* new_lhs;
+  const int64 kLhsIdx = 0;
+  if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) !=
+      operand_indices.end()) {
+    HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx);
+    const auto& transpose_dimensions = transpose.dimensions();
+    HloInstruction& transpose_operand = *transpose.mutable_operand(0);
+    // Everything remains the same except for the input/output dimension
+    // numbers. We need to apply the transpose permutation to the original
+    // shape to figure out what the new logical dimensions are.
+    new_dnums.set_input_batch_dimension(
+        transpose_dimensions[dnums.input_batch_dimension()]);
+    new_dnums.set_input_feature_dimension(
+        transpose_dimensions[dnums.input_feature_dimension()]);
+    for (const auto& spatial_dimension : dnums.spatial_dimensions()) {
+      CHECK_EQ(spatial_dimension, transpose_dimensions[spatial_dimension]);
+    }
+    new_lhs = &transpose_operand;
+  } else {
+    new_lhs = convolution.mutable_operand(kLhsIdx);
+  }
+
+  HloInstruction* new_rhs;
+  const int64 kRhsIdx = 1;
+  if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) !=
+      operand_indices.end()) {
+    HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx);
+    const auto& transpose_dimensions = transpose.dimensions();
+    HloInstruction& transpose_operand = *transpose.mutable_operand(0);
+    // Everything remains the same except for the kernel dimension numbers. We
+    // need to apply the transpose permutation to the original shape to figure
+    // out what the new logical dimensions are.
+    new_dnums.set_kernel_input_feature_dimension(
+        transpose_dimensions[dnums.kernel_input_feature_dimension()]);
+    new_dnums.set_kernel_output_feature_dimension(
+        transpose_dimensions[dnums.kernel_output_feature_dimension()]);
+    for (auto& kernel_spatial_dimension :
+         *new_dnums.mutable_kernel_spatial_dimensions()) {
+      kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
+    }
+    new_rhs = &transpose_operand;
+  } else {
+    new_rhs = convolution.mutable_operand(kRhsIdx);
+  }
 
   auto new_conv = HloInstruction::CreateConvolve(
-      convolution.shape(), convolution.mutable_operand(0), &transpose_operand,
-      convolution.window(), new_dnums);
+      convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums);
   TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
       &convolution, std::move(new_conv)));
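Both branches above apply the same rule: when an operand is produced by a transpose with permutation `dims`, each logical dimension `d` recorded in the convolution's dimension numbers is remapped to `dims[d]` on the transpose's input. A minimal standalone sketch of that rule (illustrative only, not XLA's API; the permutation below is made up):

```cpp
#include <vector>

// Remap a convolution dimension number through a transpose permutation: the
// value that lives in logical dimension d of the transpose's output comes
// from dimension dims[d] of the transpose's operand, so the folded
// convolution must record dims[d] where it previously recorded d.
int RemapDimension(const std::vector<int>& dims, int d) { return dims[d]; }

// Example: a transpose that swaps dimensions 0 and 3 of a 4-D kernel
// (dims = {3, 1, 2, 0}) maps feature dimension 0 to 3 and leaves the
// spatial dimensions 1 and 2 in place -- exactly the case the LHS check
// above requires before folding is allowed.
```
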
tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -313,8 +313,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
       new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(1));
 }
 
-// Test that a transpose of the activations does not get folded into
-// convolution.
+// Test that a transpose of the activations gets folded into convolution.
 TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
   auto builder = HloComputation::Builder("entry_computation");
   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(

@@ -348,18 +347,25 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
   module.AddEntryComputation(builder.Build(conv));
   FoldTranspose(&module);
 
-  // Instructions after folding: transpose_x, y, and the convolution.
+  // Instructions after folding: x, y, and the convolution.
   std::unordered_set<HloInstruction*> instruction_set(
       entry_computation->instructions().begin(),
       entry_computation->instructions().end());
-  CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
-  CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
-  CHECK_EQ(1, instruction_set.erase(transpose_x))
-      << "transpose_x is not in entry_computation.";
-  CHECK_EQ(1, instruction_set.erase(conv))
-      << "transpose_x is not in entry_computation.";
-  CHECK_EQ(0, instruction_set.size())
-      << "entry_computation should contain exactly 4 instructions.";
+  EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
+  EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
+  EXPECT_EQ(1, instruction_set.size())
+      << "entry_computation should contain exactly 3 instructions.";
+  HloInstruction* new_conv = *instruction_set.begin();
+  EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
+  EXPECT_EQ(dnums.input_feature_dimension(),
+            new_conv->convolution_dimension_numbers().input_batch_dimension());
+  EXPECT_EQ(
+      dnums.input_batch_dimension(),
+      new_conv->convolution_dimension_numbers().input_feature_dimension());
+  EXPECT_EQ(dnums.spatial_dimensions(0),
+            new_conv->convolution_dimension_numbers().spatial_dimensions(0));
+  EXPECT_EQ(dnums.spatial_dimensions(1),
+            new_conv->convolution_dimension_numbers().spatial_dimensions(1));
 }
 
 }  // namespace
tensorflow/compiler/xla/service/user_computation.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include <stack>
 #include <unordered_map>
 #include <utility>
+#include <vector>
 
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal_util.h"

@@ -1843,10 +1844,17 @@ UserComputation::GetEmbeddedComputations(
   XLA_VLOG_LINES(3, session_computation_.DebugString());
 
   std::vector<VersionedComputationHandle> computations;
+  std::vector<int64> sorted_handles;
   for (const auto& handle_request : session_computation_.requests()) {
-    int64 handle_value = handle_request.first;
+    sorted_handles.push_back(handle_request.first);
+  }
+  std::sort(sorted_handles.begin(), sorted_handles.end());
+  for (int64 handle : sorted_handles) {
+    const auto& handle_request = session_computation_.requests().find(handle);
+    CHECK(handle_request != session_computation_.requests().end());
+    int64 handle_value = handle_request->first;
     if (handle_value <= version) {
-      const OperationRequest& request = handle_request.second;
+      const OperationRequest& request = handle_request->second;
       switch (request.request().op_case()) {
         case OpRequest::kCallRequest: {
           CHECK_EQ(1, request.embedded_computation_versions_size());
tensorflow/compiler/xla/shape_util.cc
@@ -102,6 +102,32 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
   return true;
 }
 
+// Constructs and returns the new shape with the given minor_to_major order in
+// its Layout.
+StatusOr<Shape> MakeShapeWithLayoutInternal(
+    PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
+    tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+  if (dimensions.size() != minor_to_major.size()) {
+    return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
+                           dimensions.size(), minor_to_major.size());
+  }
+  if (element_type == OPAQUE || element_type == TUPLE) {
+    return InvalidArgument("Unsupported element type: %s",
+                           PrimitiveType_Name(element_type).c_str());
+  }
+  Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
+  auto min2maj = shape.mutable_layout()->mutable_minor_to_major();
+  min2maj->Clear();
+  for (int64 value : minor_to_major) {
+    min2maj->Add(value);
+  }
+  if (!shape.has_layout()) {
+    return InvalidArgument("Shape has no layout.");
+  }
+  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
+  return shape;
+}
+
 }  // namespace
 
 /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
@@ -152,16 +178,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
 /* static */ Shape ShapeUtil::MakeShapeWithLayout(
     PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
     tensorflow::gtl::ArraySlice<int64> minor_to_major) {
-  CHECK_EQ(dimensions.size(), minor_to_major.size());
-  Shape shape = MakeShape(element_type, dimensions);
-  auto min2maj = shape.mutable_layout()->mutable_minor_to_major();
-  min2maj->Clear();
-  for (int64 value : minor_to_major) {
-    min2maj->Add(value);
-  }
-  DCHECK(shape.has_layout());
-  TF_DCHECK_OK(ValidateShape(shape));
-  return shape;
+  return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major)
+      .ValueOrDie();
 }
 
 /* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
@@ -499,11 +517,10 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
     // Extract the layout minor-to-major and set it.
     TF_ASSIGN_OR_RETURN(std::vector<int64> min2maj,
                         comma_list_to_int64s(layout_string));
-    TF_RET_CHECK(dimensions.size() == min2maj.size());
-    result = ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj);
+    TF_ASSIGN_OR_RETURN(
+        result, MakeShapeWithLayoutInternal(primitive_type, dimensions, min2maj));
   }
-  TF_DCHECK_OK(ShapeUtil::ValidateShape(result));
+  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result));
   return std::move(result);
 }
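For context on the `minor_to_major` orders being validated here: `minor_to_major[0]` names the fastest-varying (minor-most) logical dimension, so `{1, 0}` is row-major and `{0, 1}` is column-major for a 2-D array. A standalone sketch of how such an order linearizes an index (an illustration of the layout concept, not XLA's implementation):

```cpp
#include <vector>

// Compute the linear offset of a multi-dimensional index, given dimension
// sizes and a minor_to_major order. We walk the dimensions from minor-most
// to major-most, growing the stride by each dimension's size as we go.
long LinearIndex(const std::vector<long>& dims,
                 const std::vector<long>& minor_to_major,
                 const std::vector<long>& index) {
  long offset = 0;
  long stride = 1;
  for (long dim : minor_to_major) {  // minor-most first
    offset += index[dim] * stride;
    stride *= dims[dim];
  }
  return offset;
}
```

For a 2x3 array, element (0, 2) sits at offset 2 under the row-major order `{1, 0}` but at offset 4 under the column-major order `{0, 1}`.
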
tensorflow/compiler/xla/tests/BUILD
@@ -102,28 +102,18 @@ cc_library(
    deps = [
        ":literal_test_util",
        "//tensorflow/compiler/xla:shape_layout",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
        "//tensorflow/compiler/xla/service",
        "//tensorflow/compiler/xla/service:backend",
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/compiler/xla/service:computation_layout",
        "//tensorflow/compiler/xla/service:computation_placer",
        "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_execution_profile",
        "//tensorflow/compiler/xla/service:hlo_graph_dumper",
        "//tensorflow/compiler/xla/service:transfer_manager",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/compiler/xla/service:hlo_runner",
        "//tensorflow/core:lib",
        "//tensorflow/core:stream_executor_no_cuda",
        "//tensorflow/core:test",
        "//third_party/eigen3",
    ],
)
tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -19,24 +19,9 @@ limitations under the License.
 #include <string>
 #include <utility>
 
 #define EIGEN_USE_THREADS
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/backend.h"
 #include "tensorflow/compiler/xla/service/computation_layout.h"
 #include "tensorflow/compiler/xla/service/executable.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/transfer_manager.h"
 #include "tensorflow/compiler/xla/shape_layout.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/common_runtime/eigen_thread_pool.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/types.h"
@@ -45,22 +30,6 @@ namespace se = ::perftools::gputools;
 namespace xla {
 
-// Define this in .cc file to avoid having to include eigen or forward declare
-// these types in the header.
-struct HloTestBase::EigenThreadPoolWrapper {
-  std::unique_ptr<EigenThreadPoolWrapper> pool;
-  std::unique_ptr<Eigen::ThreadPoolDevice> device;
-};
-
-HloTestBase::HloTestBase() {}
-
-HloTestBase::~HloTestBase() {
-  // Deallocate all the memory allocated during the tests.
-  for (auto& allocation : allocations_) {
-    backend().default_stream_executor()->Deallocate(&allocation);
-  }
-}
-
 /* static */ std::unique_ptr<HloModule> HloTestBase::CreateNewModule() {
   HloModuleConfig config;
@@ -80,98 +49,25 @@ StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
     tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
         arguments,
     Shape* result_shape) {
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<Executable> executable,
-      backend().compiler()->Compile(std::move(module),
-                                    backend().default_stream_executor()));
-
-  se::Stream stream(backend().default_stream_executor());
-  stream.Init();
-
-  ExecutableRunOptions run_options;
-  run_options.set_stream(&stream);
-  run_options.set_allocator(backend().memory_allocator());
-  run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool());
-  run_options.set_intra_op_thread_pool(
-      backend().eigen_intra_op_thread_pool_device());
-
-  HloExecutionProfile hlo_execution_profile;
-  ServiceExecutableRunOptions service_run_options(
-      run_options, backend().StreamBorrower(),
-      backend().inter_op_thread_pool());
-  TF_ASSIGN_OR_RETURN(
-      se::DeviceMemoryBase result,
-      executable->ExecuteOnStream(&service_run_options, arguments,
-                                  &hlo_execution_profile));
-  TF_RET_CHECK(stream.BlockHostUntilDone());
-
-  allocations_.push_back(result);
-
-  *result_shape = executable->result_shape();
-
-  if (ShapeUtil::IsTuple(*result_shape)) {
-    // We must record element buffers of tuples as well to avoid leaks.
-    DCHECK(!ShapeUtil::IsNestedTuple(*result_shape));
-    TF_ASSIGN_OR_RETURN(
-        std::vector<se::DeviceMemoryBase> element_buffers,
-        backend().transfer_manager()->ShallowCopyTupleFromDevice(
-            backend().default_stream_executor(), result, *result_shape));
-
-    // A tuple may contain the same buffer in more than one element. Keep track
-    // of the buffers already added to avoid duplicates in allocations_.
-    std::set<void*> added_opaques;
-    for (auto element_buffer : element_buffers) {
-      if (added_opaques.count(element_buffer.opaque()) == 0) {
-        CHECK(element_buffer.opaque() != nullptr);
-        added_opaques.insert(element_buffer.opaque());
-        allocations_.push_back(element_buffer);
-      }
-    }
-  }
-
-  return result;
+  return runner_.Execute(std::move(module), arguments, result_shape);
 }
 
 se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) {
-  // Allocate memory on the device using the stream executor.
-  int64 allocation_size =
-      backend().transfer_manager()->GetByteSizeRequirement(literal.shape());
-  se::DeviceMemoryBase allocation =
-      backend().default_stream_executor()->AllocateArray<uint8>(
-          allocation_size);
-  allocations_.push_back(allocation);
-
-  TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToDevice(
-      backend().default_stream_executor(), literal, &allocation));
-
-  return allocation;
+  return runner_.TransferToDevice(literal);
 }
 
 std::unique_ptr<Literal> HloTestBase::TransferFromDevice(
     const Shape& shape, se::DeviceMemoryBase device_base) {
-  auto literal = MakeUnique<Literal>();
-  TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromDevice(
-      backend().default_stream_executor(), device_base, shape, shape,
-      literal.get()));
-  return literal;
+  return runner_.TransferFromDevice(shape, device_base);
 }
 
 std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
     std::unique_ptr<HloModule> module,
     tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
-  Shape result_shape;
-  se::DeviceMemoryBase device_base =
-      Execute(std::move(module), arguments, &result_shape).ValueOrDie();
-  return TransferFromDevice(result_shape, device_base);
+  return runner_.ExecuteAndTransfer(std::move(module), arguments);
 }
 
-Backend& HloTestBase::backend() {
-  if (!backend_) {
-    backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
-    VLOG(1) << "executing on platform " << backend().platform()->Name();
-  }
-  return *backend_;
-}
+Backend& HloTestBase::backend() { return runner_.backend(); }
 
 /* static */ string HloTestBase::TestName() {
tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -21,12 +21,12 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/compiler/xla/service/backend.h"
 #include "tensorflow/compiler/xla/service/compiler.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_runner.h"
 #include "tensorflow/compiler/xla/shape_layout.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"

@@ -39,10 +39,9 @@ namespace xla {
 // building a graph of HLO instructions to run.
 class HloTestBase : public ::testing::Test {
  protected:
-  struct EigenThreadPoolWrapper;
-  HloTestBase();
+  HloTestBase() {}
 
-  ~HloTestBase() override;
+  ~HloTestBase() override {}
 
   // Creates a new HLO module for a test. The module created will have
   // TestName() for its name; it will also automatically populate its debug

@@ -102,23 +101,12 @@ class HloTestBase : public ::testing::Test {
   static string TestName();
 
-  // Creates (if necessary) and returns the default backend. If creation fails,
-  // crashes the program.
-  //
-  // This creates the backend lazily so it's possible to instantiate an
-  // HloTestBase in a program without any backends linked in.
+  // Returns the backend owned by the HloRunner.
   Backend& backend();
 
-  // This vector contains handles of all the device memory allocations performed
-  // by the test. These are deallocated on destruction of the test object.
-  std::vector<perftools::gputools::DeviceMemoryBase> allocations_;
+  HloRunner runner_;
 
   ErrorSpec error_spec_{0.0001};
-
-  std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
-
- private:
-  std::unique_ptr<Backend> backend_;  // Lazily populated. Access via backend().
 };
 
 }  // namespace xla
tensorflow/compiler/xla/tools/BUILD
@@ -210,6 +210,18 @@ tf_cc_binary(
     ],
 )
 
+tf_cc_binary(
+    name = "hlo_proto_to_json",
+    srcs = ["hlo_proto_to_json.cc"],
+    deps = [
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla/service:hlo_proto",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+    ],
+)
+
 # -----------------------------------------------------------------------------
 
 filegroup(
tensorflow/compiler/xla/tools/hlo_proto_to_json.cc
0 → 100644
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Usage:
//   hlo_proto_to_json --input_file=some_binary_proto
//   --output_file=path_to_dump_output
//
// Reads one serialized HLO module, converts it into JSON format, and dumps it
// to the given output file. some_binary_proto is obtained by serializing an
// HLO module to disk using the --xla_dump_hlo_proto_to debug option.

#include <stdio.h>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"

using tensorflow::Env;
using xla::string;

namespace xla {
namespace tools {

StatusOr<string> ToJson(const tensorflow::protobuf::Message& message) {
  string json_output;
  tensorflow::protobuf::util::JsonPrintOptions json_options;
  json_options.add_whitespace = true;
  json_options.always_print_primitive_fields = true;
  auto status = tensorflow::protobuf::util::MessageToJsonString(
      message, &json_output, json_options);
  if (!status.ok()) {
    return InternalError("MessageToJsonString failed: %s",
                         status.error_message().data());
  }
  return json_output;
}

void RealMain(const string& input, const string& output) {
  HloProto hlo_proto;
  TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), input,
                                          &hlo_proto))
      << "Can't open, read, or parse input file " << input;

  auto statusor = ToJson(hlo_proto);
  QCHECK(statusor.ok()) << "Error converting " << input << " to JSON."
                        << statusor.status();

  TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), output,
                                            statusor.ValueOrDie()));
}

}  // namespace tools
}  // namespace xla

int main(int argc, char** argv) {
  string input_file, output_file;
  const std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("input_file", &input_file, "file to convert."),
      tensorflow::Flag("output_file", &output_file, "converted file"),
  };
  const string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
  tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
  QCHECK(parse_ok && argc == 1) << "\n" << usage;

  QCHECK(!input_file.empty()) << "--input_file is required";
  QCHECK(!output_file.empty()) << "--output_file is required";

  xla::tools::RealMain(input_file, output_file);
  return 0;
}
tensorflow/compiler/xla/tools/parser/BUILD
0 → 100644
# Build file for the Hlo parser.

licenses(["notice"])  # Apache 2.0

package(
    default_visibility = [":friends"],
)

package_group(
    name = "friends",
    includes = [
        "//tensorflow/compiler/xla:friends",
    ],
)

# Filegroup used to collect source files for dependency checking.
filegroup(
    name = "c_srcs",
    data = glob([
        "**/*.cc",
        "**/*.h",
    ]),
)

load("//tensorflow:tensorflow.bzl", "tf_cc_test")

cc_library(
    name = "hlo_lexer",
    srcs = ["hlo_lexer.cc"],
    hdrs = [
        "hlo_lexer.h",
        "hlo_token.h",
    ],
    deps = [
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/core:lib",
        "//tensorflow/core:regexp_internal",
    ],
)

cc_library(
    name = "hlo_parser",
    srcs = ["hlo_parser.cc"],
    hdrs = ["hlo_parser.h"],
    deps = [
        ":hlo_lexer",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
    ],
)

tf_cc_test(
    name = "hlo_parser_test",
    size = "small",
    srcs = ["hlo_parser_test.cc"],
    deps = [
        ":hlo_parser",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

# -----------------------------------------------------------------------------

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
tensorflow/compiler/xla/tools/parser/README.md
0 → 100644
# HloModule string syntax
TODO: Support subcomputations (for fusion, reduce, while, ...).
TODO: Support ops that require extra attributes, e.g. dimensions, strides.
```yacc
hlo_module
: 'HloModule' name computation
;
computation
: 'ENTRY' name param_list '->' shape instruction_list
;
instruction_list
: '{' instruction_list1 '}'
;
instruction_list1
: instruction
| instruction_list1 instruction
;
instruction
: name '=' shape opcode operands
;
operands
: '(' operands1 ')'
;
operands1
: /*empty*/
| operand
| operands1 ',' operand
;
operand
: shape name
;
param_list
: '(' param_list1 ')'
;
param_list1
: /*empty*/
| param
| param_list1 ',' param
;
param
: name shape
;
shape
: shape_val_
| '(' tuple_elements ')'
;
tuple_elements
: /*empty*/
| shape (',' shape)*
;
name
: identifier ':'
| '%' identifier
;
identifier
: [a-zA-Z_][a-zA-Z0-9_.-]*
;
```
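As a concrete illustration, a module string this grammar would accept might look like the following (the module, computation, and instruction names and the shapes are made up for the example):

```
HloModule AddModule:

ENTRY %add.v1 (x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
  %sum = f32[2,4] add(f32[2,4] %x, f32[2,4] %y)
}
```

Note the two forms of `name` in the grammar: `identifier ':'` after `HloModule` and in the parameter list, and `'%' identifier` for computation and instruction names and operand references.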
tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
0 → 100644
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h"
#include <unordered_map>
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/regexp.h"
namespace
xla
{
namespace
tools
{
using
tensorflow
::
StringPiece
;
namespace
{
constexpr
int
kEOF
=
-
1
;
constexpr
int
kError
=
-
2
;
// [a-zA-Z0-9_.-]
bool
IsIdentifierChar
(
char
c
)
{
return
isalnum
(
static_cast
<
unsigned
char
>
(
c
))
||
c
==
'-'
||
c
==
'.'
||
c
==
'_'
;
}
}
// namespace
int
HloLexer
::
GetNextChar
()
{
int
current_char
=
PeekCurrentChar
();
if
(
current_char
!=
kEOF
&&
current_char
!=
kError
)
{
current_ptr_
++
;
}
return
current_char
;
}
int
HloLexer
::
PeekCurrentChar
()
const
{
if
(
current_ptr_
==
buf_
.
end
())
{
return
kEOF
;
}
char
current_char
=
*
current_ptr_
;
if
(
current_char
==
0
)
{
// '\0' should not appear in the middle of the string.
return
kError
;
}
return
static_cast
<
unsigned
char
>
(
current_char
);
}
bool
HloLexer
::
CanDereference
(
const
char
*
ptr
)
const
{
return
ptr
<
buf_
.
end
()
&&
ptr
>=
buf_
.
begin
();
}
StringPiece
HloLexer
::
StringPieceFromPointers
(
const
char
*
begin
,
const
char
*
end
)
const
{
CHECK
(
begin
<=
end
);
CHECK
(
begin
==
buf_
.
end
()
||
CanDereference
(
begin
));
CHECK
(
end
==
buf_
.
end
()
||
CanDereference
(
end
));
return
StringPiece
(
begin
,
end
-
begin
);
}
tensorflow
::
RegexpStringPiece
HloLexer
::
RegexpStringPieceFromPointers
(
const
char
*
begin
,
const
char
*
end
)
const
{
CHECK
(
begin
<=
end
);
CHECK
(
begin
==
buf_
.
end
()
||
CanDereference
(
begin
));
CHECK
(
end
==
buf_
.
end
()
||
CanDereference
(
end
));
return
tensorflow
::
RegexpStringPiece
(
begin
,
end
-
begin
);
}
TokKind
HloLexer
::
LexToken
()
{
while
(
true
)
{
token_start_
=
current_ptr_
;
int
current_char
=
GetNextChar
();
switch
(
current_char
)
{
default:
// [a-zA-Z_]
if
(
isalpha
(
static_cast
<
unsigned
char
>
(
current_char
))
||
current_char
==
'_'
)
{
return
LexIdentifier
();
}
return
TokKind
::
kError
;
case
kEOF
:
// Hit the end of the input buffer.
return
TokKind
::
kEof
;
case
kError
:
// Hit an invalid character in the input buffer.
return
TokKind
::
kError
;
case
' '
:
case
'\t'
:
case
'\n'
:
case
'\r'
:
// Ignore whitespace.
continue
;
case
'0'
:
case
'1'
:
case
'2'
:
case
'3'
:
case
'4'
:
case
'5'
:
case
'6'
:
case
'7'
:
case
'8'
:
case
'9'
:
case
'-'
:
if
(
current_char
==
'-'
&&
PeekCurrentChar
()
==
'>'
)
{
current_ptr_
++
;
return
TokKind
::
kArrow
;
}
return
LexDigitOrNegative
();
case
'='
:
return
TokKind
::
kEqual
;
case
','
:
return
TokKind
::
kComma
;
case
'%'
:
return
LexPercent
();
case
':'
:
return
TokKind
::
kColon
;
case
'['
:
return
TokKind
::
kLsquare
;
case
']'
:
return
TokKind
::
kRsquare
;
case
'{'
:
return
TokKind
::
kLbrace
;
case
'}'
:
return
TokKind
::
kRbrace
;
case
'('
:
return
TokKind
::
kLparen
;
case
')'
:
return
TokKind
::
kRparen
;
}
}
}
// Lex a shape, name, keyword, or opcode.
// shape   ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})?
// name    ::= [a-zA-Z_][a-zA-Z0-9_.-]*:
// keyword ::= HloModule, ENTRY, ...
// opcode  ::= add, greater-than, ...
TokKind HloLexer::LexIdentifier() {
  {
    auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
    // 'consumable' will be advanced iff its prefix matches the pattern.
    static LazyRE2 shape_pattern = {
        R"(^(\w*\d*)\[([\d,]*)\](?:\s*{([\d,]*)})?)"};
    if (RE2::Consume(&consumable, *shape_pattern)) {
      auto status_or_shape = ShapeUtil::ParseShapeString(
          StringPieceFromPointers(token_start_, consumable.begin()));
      if (status_or_shape.ok()) {
        // This is a shape string.
        shape_val_ = status_or_shape.ValueOrDie();
        current_ptr_ = consumable.begin();
        return TokKind::kShape;
      }
    }
  }

  while (IsIdentifierChar(PeekCurrentChar())) {
    current_ptr_++;
  }

  // If followed by ':', it's a name.
  if (PeekCurrentChar() == ':') {
    str_val_.assign(token_start_, current_ptr_);
    current_ptr_++;  // skip ':'
    return TokKind::kName;
  }

  StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_);

  // See if this is a keyword.
#define KEYWORD(STR)            \
  do {                          \
    if (identifier == #STR) {   \
      return TokKind::kw_##STR; \
    }                           \
  } while (false)

  KEYWORD(true);
  KEYWORD(false);
  KEYWORD(HloModule);
  KEYWORD(ENTRY);

#undef KEYWORD

  // See if this is an opcode.
  auto opcode = StringToHloOpcode(identifier.ToString());
  if (opcode.ok()) {
    opcode_val_ = opcode.ValueOrDie();
    return TokKind::kOpcode;
  }

  current_ptr_ = token_start_ + 1;
  return TokKind::kError;
}
// Lex names after a % character.
// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*
TokKind HloLexer::LexPercent() {
  const char* name_start = current_ptr_;
  if (isalpha(static_cast<unsigned char>(PeekCurrentChar())) ||
      PeekCurrentChar() == '_') {
    current_ptr_++;
    while (IsIdentifierChar(PeekCurrentChar())) {
      current_ptr_++;
    }
    str_val_.assign(name_start, current_ptr_);
    return TokKind::kName;
  }
  return TokKind::kError;
}
// Lex integer and floating-point values.
// int            [-]?[0-9]+
// fp with exp    [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
TokKind HloLexer::LexDigitOrNegative() {
  auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
  static LazyRE2 float_pattern = {
      R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"};
  if (RE2::Consume(&consumable, *float_pattern)) {
    current_ptr_ = consumable.begin();
    tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(),
                                     &decimal_val_);
    return TokKind::kDecimal;
  }

  static LazyRE2 int_pattern = {R"([-]?\d+)"};
  if (RE2::Consume(&consumable, *int_pattern)) {
    current_ptr_ = consumable.begin();
    tensorflow::strings::safe_strto64(
        StringPieceFromPointers(token_start_, current_ptr_), &int64_val_);
    return TokKind::kInt;
  }

  return TokKind::kError;
}
StringPiece HloLexer::GetCurrentLine() const {
  const char* start = token_start_;
  const char* end = current_ptr_;
  if (!CanDereference(start) || !CanDereference(end)) {
    return "LINE OUT OF RANGE";
  }
  while (start > buf_.begin() && *start != '\n') {
    start--;
  }
  while (end < buf_.end() && *end != '\n') {
    end++;
  }
  return StringPieceFromPointers(start, end);
}

}  // namespace tools
}  // namespace xla
tensorflow/compiler/xla/tools/parser/hlo_lexer.h
0 → 100644
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
#include <string>
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_token.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace tools {

// Lexer for the HloModule::ToString() format text.
class HloLexer {
 public:
  explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) {
    current_ptr_ = buf_.begin();
  }

  TokKind Lex() { return current_kind_ = LexToken(); }

  TokKind GetKind() const { return current_kind_; }
  string GetStrVal() const {
    CHECK(GetKind() == TokKind::kName);
    return str_val_;
  }
  Shape GetShapeVal() const {
    CHECK(GetKind() == TokKind::kShape);
    return shape_val_;
  }
  HloOpcode GetOpcodeVal() const {
    CHECK(GetKind() == TokKind::kOpcode);
    return opcode_val_;
  }
  int64 GetInt64Val() const {
    CHECK(GetKind() == TokKind::kInt);
    return int64_val_;
  }
  double GetDecimalVal() const {
    CHECK(GetKind() == TokKind::kDecimal);
    return decimal_val_;
  }

  // Returns the line of text that is currently being lexed.
  tensorflow::StringPiece GetCurrentLine() const;

 private:
  // Returns the current character. If it's neither the end of input buffer nor
  // an invalid character, moves the pointer forward.
  int GetNextChar();

  // Returns the current character.
  int PeekCurrentChar() const;

  // Creates a StringPiece with the given begin and end. Exits if begin > end
  // or either pointer is out of the range of the current buffer.
  tensorflow::StringPiece StringPieceFromPointers(const char* begin,
                                                  const char* end) const;
  tensorflow::RegexpStringPiece RegexpStringPieceFromPointers(
      const char* begin, const char* end) const;

  // Returns true if the given ptr is dereferenceable within the range of the
  // current buffer.
  bool CanDereference(const char* ptr) const;

  TokKind LexToken();
  TokKind LexIdentifier();
  TokKind LexPercent();
  TokKind LexShape();
  TokKind LexConstant();
  TokKind LexDigitOrNegative();

  const tensorflow::StringPiece buf_;
  const char* current_ptr_;

  // Information about the current token.
  const char* token_start_;
  TokKind current_kind_;
  string str_val_;
  Shape shape_val_;
  HloOpcode opcode_val_;
  int64 int64_val_;
  double decimal_val_;
};

}  // namespace tools
}  // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
tensorflow/compiler/xla/tools/parser/hlo_parser.cc
0 → 100644
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace tools {

namespace {

using tensorflow::StringPiece;
using tensorflow::strings::StrCat;

// Parser for the HloModule::ToString() format text.
class HloParser {
 public:
  explicit HloParser(StringPiece str) : lexer_(str) {}

  // Runs the parser. Returns false if an error occurred.
  bool Run();

  // Returns the parsed HloModule.
  std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }

  // Returns the error information.
  string GetError() const { return tensorflow::str_util::Join(error_, "\n"); }

 private:
  // ParseXXX returns false if an error occurred.
  bool ParseHloModule();
  bool ParseComputation();
  bool ParseInstructionList(HloComputation::Builder* builder);
  bool ParseInstruction(HloComputation::Builder* builder);
  bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
  bool ParseOperands(std::vector<HloInstruction*>* operands,
                     const int expected_size);
  bool ParseParamList();
  bool ParseName(string* result);
  bool ParseShape(Shape* result);
  bool ParseOpcode(HloOpcode* result);
  bool ParseInt64(int64* result);
  bool ParseDecimal(double* result);
  bool ParseBool(bool* result);
  bool ParseToken(TokKind kind, const string& msg);

  // Logs the current parsing line and the given message. Always returns false.
  bool TokenError(StringPiece msg);

  // If the current token is 'kind', eats it (i.e. lexes the next token) and
  // returns true.
  bool EatIfPresent(TokKind kind);

  // Adds the instruction to the pool. Returns false and emits an error if the
  // instruction already exists.
  bool AddInstruction(const string& name, HloInstruction* instruction);

  // The map from the instruction name to the instruction. This does not own
  // the instructions.
  std::unordered_map<string, HloInstruction*> instruction_pool_;

  HloLexer lexer_;
  std::unique_ptr<HloModule> module_;
  std::vector<string> error_;
};

bool HloParser::TokenError(StringPiece msg) {
  error_.push_back(
      StrCat("was parsing \"", lexer_.GetCurrentLine(), "\"; ", msg));
  return false;
}

bool HloParser::Run() {
  lexer_.Lex();
  return ParseHloModule();
}

// ::= 'HloModule' name computation
bool HloParser::ParseHloModule() {
  if (lexer_.GetKind() != TokKind::kw_HloModule) {
    return TokenError("expects HloModule");
  }
  // Eat 'HloModule'
  lexer_.Lex();

  string name;
  if (!ParseName(&name)) {
    return false;
  }

  module_ = MakeUnique<HloModule>(name);

  return ParseComputation();
}

// computation ::= 'ENTRY' name param_list '->' shape instruction_list
bool HloParser::ParseComputation() {
  string name;
  if (!ParseToken(TokKind::kw_ENTRY, "expects 'ENTRY'") || !ParseName(&name)) {
    return false;
  }
  auto builder = MakeUnique<HloComputation::Builder>(name);

  Shape shape;
  if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'") ||
      !ParseShape(&shape) || !ParseInstructionList(builder.get())) {
    return false;
  }

  module_->AddEntryComputation(builder->Build());
  return true;
}

// instruction_list ::= '{' instruction_list1 '}'
// instruction_list1 ::= (instruction)+
bool HloParser::ParseInstructionList(HloComputation::Builder* builder) {
  if (!ParseToken(TokKind::kLbrace,
                  "expects '{' at the beginning of instruction list.")) {
    return false;
  }
  do {
    if (!ParseInstruction(builder)) {
      return false;
    }
  } while (lexer_.GetKind() != TokKind::kRbrace);
  return ParseToken(TokKind::kRbrace,
                    "expects '}' at the end of instruction list.");
}

// instruction ::= name '=' shape opcode operands
bool HloParser::ParseInstruction(HloComputation::Builder* builder) {
  string name;
  Shape shape;
  HloOpcode opcode;
  std::vector<HloInstruction*> operands;
  if (!ParseName(&name) ||
      !ParseToken(TokKind::kEqual, "expects '=' in instruction") ||
      !ParseShape(&shape) || !ParseOpcode(&opcode)) {
    return false;
  }
  switch (opcode) {
    case HloOpcode::kParameter: {
      int64 parameter_number;
      return ParseToken(TokKind::kLparen,
                        "expects '(' before parameter number") &&
             ParseInt64(&parameter_number) &&
             ParseToken(TokKind::kRparen,
                        "expects ')' after parameter number") &&
             AddInstruction(
                 name, builder->AddInstruction(HloInstruction::CreateParameter(
                           parameter_number, shape, name)));
    }
    case HloOpcode::kConstant: {
      std::unique_ptr<Literal> literal;
      return ParseToken(TokKind::kLparen,
                        "expects '(' before constant literal") &&
             ParseLiteral(&literal, shape) &&
             ParseToken(TokKind::kRparen,
                        "expects ')' after constant literal") &&
             AddInstruction(
                 name, builder->AddInstruction(HloInstruction::CreateConstant(
                           std::move(literal))));
    }
    // Unary ops.
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kBitcast:
    case HloOpcode::kCeil:
    case HloOpcode::kCopy:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kIsFinite:
    case HloOpcode::kFloor:
    case HloOpcode::kLog:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSort:
    case HloOpcode::kTanh: {
      return ParseOperands(&operands, /*expected_size=*/1) &&
             AddInstruction(
                 name, builder->AddInstruction(HloInstruction::CreateUnary(
                           shape, opcode, operands[0])));
    }
    // Binary ops.
    case HloOpcode::kAdd:
    case HloOpcode::kDivide:
    case HloOpcode::kMultiply:
    case HloOpcode::kSubtract:
    case HloOpcode::kEq:
    case HloOpcode::kGe:
    case HloOpcode::kGt:
    case HloOpcode::kLe:
    case HloOpcode::kLt:
    case HloOpcode::kNe:
    case HloOpcode::kDot:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical: {
      return ParseOperands(&operands, /*expected_size=*/2) &&
             AddInstruction(
                 name, builder->AddInstruction(HloInstruction::CreateBinary(
                           shape, opcode, operands[0], operands[1])));
    }
    // Ternary ops.
    case HloOpcode::kClamp:
    case HloOpcode::kSelect: {
      return ParseOperands(&operands, /*expected_size=*/3) &&
             AddInstruction(
                 name, builder->AddInstruction(HloInstruction::CreateTernary(
                           shape, opcode, operands[0], operands[1],
                           operands[2])));
    }
    // Other supported ops.
    case HloOpcode::kConvert: {
      return ParseOperands(&operands, /*expected_size=*/1) &&
             AddInstruction(
                 name, builder->AddInstruction(HloInstruction::CreateConvert(
                           shape, operands[0])));
    }
    case HloOpcode::kCrossReplicaSum: {
      return ParseOperands(&operands, /*expected_size=*/1) &&
             AddInstruction(
                 name,
                 builder->AddInstruction(HloInstruction::CreateCrossReplicaSum(
                     shape, operands[0])));
    }
    case HloOpcode::kReshape: {
      return ParseOperands(&operands, /*expected_size=*/1) &&
             AddInstruction(
                 name, builder->AddInstruction(HloInstruction::CreateReshape(
                           shape, operands[0])));
    }
    case HloOpcode::kBroadcast:
    case HloOpcode::kCall:
    case HloOpcode::kCustomCall:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kConvolution:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kMap:
    case HloOpcode::kPad:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kReverse:
    case HloOpcode::kRng:
    case HloOpcode::kSlice:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kTranspose:
    case HloOpcode::kTuple:
    case HloOpcode::kWhile:
    case HloOpcode::kFusion:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kRecv:
    case HloOpcode::kSend:
    case HloOpcode::kUpdate:
    case HloOpcode::kIndex:
    case HloOpcode::kTrace:
      return TokenError(StrCat("parsing not yet implemented for op: ",
                               HloOpcodeString(opcode)));
  }
}

bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
                             const Shape& shape) {
  switch (shape.element_type()) {
    case PRED:
      bool b;
      if (!ParseBool(&b)) {
        return false;
      }
      *literal = Literal::CreateR0<bool>(b);
      return true;
    case S32:
      int64 i;
      if (!ParseInt64(&i)) {
        return false;
      }
      *literal = Literal::CreateR0<int32>(i);
      return true;
    case F32:
      double d;
      if (!ParseDecimal(&d)) {
        return false;
      }
      *literal = Literal::CreateR0<float>(d);
      return true;
    default:
      return TokenError(StrCat("unsupported constant in shape: ",
                               ShapeUtil::HumanString(shape)));
  }
}

// operands ::= '(' operands1 ')'
// operands1
//   ::= /*empty*/
//   ::= operand (, operand)*
// operand ::= shape name
bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
                              const int expected_size) {
  if (!ParseToken(TokKind::kLparen,
                  "expects '(' at the beginning of operands")) {
    return false;
  }
  if (lexer_.GetKind() == TokKind::kRparen) {
    // empty
  } else {
    do {
      Shape shape;
      string name;
      if (!ParseShape(&shape) || !ParseName(&name)) {
        return false;
      }
      HloInstruction* instruction =
          tensorflow::gtl::FindPtrOrNull(instruction_pool_, name);
      if (!instruction) {
        return TokenError(StrCat("instruction does not exist: ", name));
      }
      operands->push_back(instruction);
    } while (EatIfPresent(TokKind::kComma));
  }
  if (expected_size != operands->size()) {
    return TokenError(StrCat("expects ", expected_size, " operands, but has ",
                             operands->size(), " operands"));
  }
  return ParseToken(TokKind::kRparen, "expects ')' at the end of operands");
}

// param_list ::= '(' param_list1 ')'
// param_list1
//   ::= /*empty*/
//   ::= param (',' param)*
// param ::= name shape
bool HloParser::ParseParamList() {
  if (!ParseToken(TokKind::kLparen,
                  "expects '(' at the beginning of param list")) {
    return false;
  }
  if (lexer_.GetKind() == TokKind::kRparen) {
    // empty
  } else {
    do {
      Shape shape;
      if (!ParseToken(TokKind::kName, "expects name in parameter") ||
          !ParseShape(&shape)) {
        return false;
      }
    } while (EatIfPresent(TokKind::kComma));
  }
  return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
}

// shape ::= shape_val_
// shape ::= '(' tuple_elements ')'
// tuple_elements
//   ::= /*empty*/
//   ::= shape (',' shape)*
bool HloParser::ParseShape(Shape* result) {
  if (EatIfPresent(TokKind::kLparen)) {  // Tuple
    std::vector<Shape> shapes;
    if (lexer_.GetKind() == TokKind::kRparen) {
      /*empty*/
    } else {
      // shape (',' shape)*
      do {
        shapes.emplace_back();
        if (!ParseShape(&shapes.back())) {
          return false;
        }
      } while (EatIfPresent(TokKind::kComma));
    }
    *result = ShapeUtil::MakeTupleShape(shapes);
    return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
  }

  if (lexer_.GetKind() != TokKind::kShape) {
    return TokenError("expects shape");
  }
  *result = lexer_.GetShapeVal();
  lexer_.Lex();
  return true;
}

bool HloParser::ParseName(string* result) {
  VLOG(1) << "ParseName";
  if (lexer_.GetKind() != TokKind::kName) {
    return TokenError("expects name");
  }
  *result = lexer_.GetStrVal();
  lexer_.Lex();
  return true;
}

bool HloParser::ParseOpcode(HloOpcode* result) {
  VLOG(1) << "ParseOpcode";
  if (lexer_.GetKind() != TokKind::kOpcode) {
    return TokenError("expects opcode");
  }
  *result = lexer_.GetOpcodeVal();
  lexer_.Lex();
  return true;
}

bool HloParser::ParseInt64(int64* result) {
  VLOG(1) << "ParseInt64";
  if (lexer_.GetKind() != TokKind::kInt) {
    return TokenError("expects integer");
  }
  *result = lexer_.GetInt64Val();
  lexer_.Lex();
  return true;
}

bool HloParser::ParseDecimal(double* result) {
  switch (lexer_.GetKind()) {
    case TokKind::kDecimal:
      *result = lexer_.GetDecimalVal();
      break;
    case TokKind::kInt:
      *result = static_cast<double>(lexer_.GetInt64Val());
      break;
    default:
      return TokenError("expects decimal or integer");
  }
  lexer_.Lex();
  return true;
}

bool HloParser::ParseBool(bool* result) {
  if (lexer_.GetKind() != TokKind::kw_true &&
      lexer_.GetKind() != TokKind::kw_false) {
    return TokenError("expects true or false");
  }
  *result = lexer_.GetKind() == TokKind::kw_true;
  lexer_.Lex();
  return true;
}

bool HloParser::ParseToken(TokKind kind, const string& msg) {
  if (lexer_.GetKind() != kind) {
    return TokenError(msg);
  }
  lexer_.Lex();
  return true;
}

bool HloParser::EatIfPresent(TokKind kind) {
  if (lexer_.GetKind() != kind) {
    return false;
  }
  lexer_.Lex();
  return true;
}

bool HloParser::AddInstruction(const string& name,
                               HloInstruction* instruction) {
  auto result = instruction_pool_.insert({name, instruction});
  if (!result.second) {
    return TokenError(StrCat("instruction already exists: ", name));
  }
  return true;
}

}  // namespace

StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str) {
  HloParser parser(str);
  if (!parser.Run()) {
    return InvalidArgument("Syntax error: %s", parser.GetError().c_str());
  }
  return parser.ConsumeHloModule();
}

}  // namespace tools
}  // namespace xla
tensorflow/compiler/xla/tools/parser/hlo_parser.h
0 → 100644
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace tools {

// The API of the HLO parser. Given a string in the HloModule::ToString()
// format, returns the parsed HloModule.
StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str);

}  // namespace tools
}  // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
0 → 100644
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include <string>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace tools {
namespace {

struct TestData {
  string test_name;
  string module_string;
};

string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
  return data.param.test_name;
}

std::vector<TestData> CreateTestCases() {
  // clang-format off
  return std::vector<TestData>({
// ax + y
{
"AxpyParam",
R"(HloModule axpy_module:
ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
  %alpha = f32[2,4]{1,0} parameter(0)
  %x = f32[2,4]{1,0} parameter(1)
  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %alpha, f32[2,4]{1,0} %x)
  %y = f32[2,4]{1,0} parameter(2)
  %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)"
},
// pred constant
{
"ConstantPred",
R"(HloModule constant_pred_module:
ENTRY %constant_pred () -> pred[] {
  %constant = pred[] constant(true)
}
)"
},
// s32 constant
{
"ConstantS32",
R"(HloModule constant_s32_module:
ENTRY %constant_s32 () -> s32[] {
  %constant = s32[] constant(-42)
}
)"
},
// f32 constant, but the value is not a decimal
{
"ConstantF32",
R"(HloModule ConstantF32_module:
ENTRY %ConstantF32.v4 () -> f32[] {
  %constant = f32[] constant(42)
}
)"
},
// constant + constant
{
"AddConstants",
R"(HloModule add_constants_module:
ENTRY %add_constants () -> f32[] {
  %constant = f32[] constant(3.14)
  %add = f32[] add(f32[] %constant, f32[] %constant)
}
)"
},
// v1 > v2 ? v1 : v2
{
"SelectR1F32",
R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module:
ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
  %v1 = f32[4]{0} parameter(0)
  %v2 = f32[4]{0} parameter(1)
  %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2)
  %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2)
}
)"
}
  });
  // clang-format on
}

class HloParserTest : public ::testing::Test,
                      public ::testing::WithParamInterface<TestData> {
 protected:
  void ExpectSuccess() {
    const string& original = GetParam().module_string;
    auto result = Parse(original);
    TF_EXPECT_OK(result.status());
    EXPECT_EQ(original, result.ValueOrDie()->ToString());
  }
};

TEST_P(HloParserTest, Run) { ExpectSuccess(); }

INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest,
                        ::testing::ValuesIn(CreateTestCases()),
                        TestDataToString);

TEST_F(HloParserTest, Empty) {
  const string original = "";
  auto result = Parse(original);
  EXPECT_NE(tensorflow::Status::OK(), result.status());
}

TEST_F(HloParserTest, Garbage) {
  const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$";
  auto result = Parse(original);
  EXPECT_NE(tensorflow::Status::OK(), result.status());
}

TEST_F(HloParserTest, WrongOpcode) {
  const string original = R"(HloModule wrong_opcode:
ENTRY %blabla (x: f32[], y: f32[]) -> f32[] {
  %x = f32[]{} parameter(0)
  %y = f32[]{} parameter(1)
  %le = pred[]{} le(f32[]{} %x, f32[]{} %y)
}
)";
  auto result = Parse(original);
  EXPECT_NE(tensorflow::Status::OK(), result.status());
}

TEST_F(HloParserTest, WrongShape) {
  const string original = R"(HloModule wrong_shape:
ENTRY %blabla (x: g32[]) -> g32[] {
  %x = g32[]{} parameter(0)
}
)";
  auto result = Parse(original);
  EXPECT_NE(tensorflow::Status::OK(), result.status());
}

TEST_F(HloParserTest, WrongOperandsSize) {
  const string original = R"(HloModule wrong_operands_size:
ENTRY %blabla (x: f32[]) -> pred[] {
  %x = f32[]{} parameter(0)
  %eq = pred[]{} equal-to(f32[]{} %x)
}
)";
  auto result = Parse(original);
  EXPECT_NE(tensorflow::Status::OK(), result.status());
}

TEST_F(HloParserTest, OperandNotFound) {
  const string original = R"(HloModule operand_not_found:
ENTRY %blabla (x: f32[]) -> pred[] {
  %x = f32[]{} parameter(0)
  %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y)
}
)";
  auto result = Parse(original);
  EXPECT_NE(tensorflow::Status::OK(), result.status());
}

TEST_F(HloParserTest, MoreConstants) {
  const string original = R"(HloModule SelectScalarS32True_module:
ENTRY %SelectScalarS32True.v4 () -> s32[] {
  %constant.2 = pred[] constant(true)
  %constant.1 = s32[] constant(-42)
  %constant = s32[] constant(42)
  %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
}
)";
  auto result = Parse(original);
  TF_EXPECT_OK(result.status());
  // Constant instructions have no name. The string will be parsed successfully
  // but the constant names will not be exactly the same.
}

TEST_F(HloParserTest, ConstantWithExp) {
  const string original = R"(HloModule ConstantWithExp_module:
ENTRY %ConstantWithExp.v4 () -> f32[] {
  %constant.1 = f32[] constant(3e+2)
}
)";
  auto result = Parse(original);
  TF_EXPECT_OK(result.status());
  // The string will be parsed successfully but the output strings are not
  // exactly the same, because "3e2" is parsed into value 300 and will be
  // printed as "300".
}

TEST_F(HloParserTest, Tuple) {
  const string original = R"(HloModule EmptyTupleCreate_module:
ENTRY %EmptyTupleCreate.v1 () -> () {
  %tuple = () tuple()
}
)";
  auto result = Parse(original);
  EXPECT_NE(tensorflow::Status::OK(), result.status());
}

}  // namespace
}  // namespace tools
}  // namespace xla
tensorflow/compiler/xla/tools/parser/hlo_token.h
0 → 100644
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
namespace xla {
namespace tools {

// Defines different kinds of tokens in a hlo module string.
enum class TokKind {
  // Markers
  kEof,
  kError,

  // Tokens with no info.
  kEqual,              // =
  kComma,              // ,
  kColon,              // :
  kLsquare, kRsquare,  // [  ]
  kLbrace, kRbrace,    // {  }
  kLparen, kRparen,    // (  )
  kArrow,              // ->

  // Keywords
  kw_HloModule,
  kw_ENTRY,
  kw_true,
  kw_false,

  // Typed tokens.
  kName,     // %foo
  kShape,    // f32[2,3]{1,0}
  kOpcode,   // add
  kInt,      // 42
  kDecimal,  // 4.2
};

}  // namespace tools
}  // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
tensorflow/compiler/xla/xla.proto
@@ -82,8 +82,8 @@ message DebugOptions {
   // Dump all HLO modules as text into the provided directory path.
   string xla_generate_hlo_text_to = 7;

-  // Dump compilation artifacts as JSON into this directory.
-  string xla_dump_debug_json_to = 8;
+  // Dump compilation artifacts in binary proto into this directory.
+  string xla_dump_hlo_proto_to = 8;

   // Instrument the computation to collect per-HLO cycle counts.
   bool xla_hlo_profile = 9;
tensorflow/contrib/batching/BUILD
@@ -69,6 +69,28 @@ tf_cc_test(
    ],
)

cc_library(
    name = "adaptive_shared_batch_scheduler",
    hdrs = ["adaptive_shared_batch_scheduler.h"],
    deps = [
        ":batch_scheduler",
        "//tensorflow/contrib/batching/util:periodic_function_dynamic",
        "//tensorflow/core:lib",
    ],
)

tf_cc_test(
    name = "adaptive_shared_batch_scheduler_test",
    srcs = ["adaptive_shared_batch_scheduler_test.cc"],
    deps = [
        ":adaptive_shared_batch_scheduler",
        "//tensorflow/contrib/batching/test_util:fake_clock_env",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

cc_library(
    name = "basic_batch_scheduler",
    hdrs = ["basic_batch_scheduler.h"],
tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
0 → 100644
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#include <functional>
#include <memory>
#include <queue>
#include <unordered_map>
#include <vector>
#include "tensorflow/contrib/batching/batch_scheduler.h"
#include "tensorflow/contrib/batching/util/periodic_function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace serving {

namespace internal {
template <typename TaskType>
class ASBSBatch;

template <typename TaskType>
class ASBSQueue;
}  // namespace internal
// Shared batch scheduler designed to minimize latency. The scheduler keeps
// track of a number of queues (one per model or model version) which are
// continuously enqueuing requests. The scheduler groups the requests into
// batches which it periodically sends off for processing (see
// shared_batch_scheduler.h for more details). The AdaptiveSharedBatchScheduler
// prioritizes batches by age (i.e. the batch's oldest request) irrespective of
// queue. The scheduler will process the oldest batch at an adjustable rate,
// regardless of batch size. The user can provide feedback to help set this rate
// to achieve some goal (i.e. minimize overall latency, limit cpu usage, etc).
//
// The rate (or rather, the corresponding period) is adjusted each time a batch
// is processed, using an exponentially weighted moving average to smooth
// potentially noisy feedback:
// ewma_feedback = ((N - 1) * ewma_feedback + feedback()) / N
// period *= (1 + K * emwa_feedback)
//
// Some potential use cases:
// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing
// involves serial processing by a device, from a latency perspective it is
// desirable to keep the device evenly loaded, avoiding the need to wait for
// the device to process prior batches.
// feedback = num_pending_on_device() - desired_pending.
// CPU utilization - If the batch processing is cpu dominated, you can reap
// latency gains when underutilized by increasing the processing rate, but
// back the rate off when the load increases to avoid overload.
// feedback = cpu_rate() - desired_cpu_rate.
template <typename TaskType>
class AdaptiveSharedBatchScheduler
    : public std::enable_shared_from_this<
          AdaptiveSharedBatchScheduler<TaskType>> {
 public:
  struct Options {
    // The name to use for the pool of batch threads.
    string thread_pool_name = {"batch_threads"};
    // Number of batch processing threads; equivalently the maximum number of
    // concurrently running batches.
    int64 num_batch_threads = port::NumSchedulableCPUs();
    // The environment to use (typically only overridden by test code).
    Env* env = Env::Default();
    // Initial batch scheduling period in microseconds. Will be altered for
    // non-zero rate_feedback.
    double initial_scheduling_period_micros = 500;
    // Minimum batch scheduling period in microseconds. Recommend setting this
    // value greater than 0, otherwise it may take a while to recover from a
    // sustained time of negative scheduling_period_feedback (which may occur
    // under low load).
    double min_scheduling_period_micros = 100;
    // Maximum batch scheduling period in microseconds.
    double max_scheduling_period_micros = 10000;
    // Feedback function used to modify the scheduling period each time a batch
    // is scheduled. Should return values roughly O(1), with positive values
    // resulting in an increased period.
    std::function<double()> scheduling_period_feedback = [] { return 0.; };
    // To handle potentially noisy scheduling_period_feedback, the period is
    // adjusted using an exponentially weighted moving average over the previous
    // feedback_smoothing_batches batches. Must be greater than 0.
    int64 feedback_smoothing_batches = 10;
  };

  // Ownership is shared between the caller of Create() and any queues created
  // via AddQueue().
  static Status Create(
      const Options& options,
      std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler);

  struct QueueOptions {
    // Maximum size of each batch.
    int max_batch_size = 1000;
    // Maximum number of enqueued (i.e. non-scheduled) batches.
    int max_enqueued_batches = 10;
  };

  using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;

  // Adds queue (and its callback) to be managed by this scheduler.
  Status AddQueue(const QueueOptions& options,
                  BatchProcessor process_batch_callback,
                  std::unique_ptr<BatchScheduler<TaskType>>* queue);

 private:
  // access to AddBatch, RemoveQueue, GetEnv.
  friend class internal::ASBSQueue<TaskType>;

  explicit AdaptiveSharedBatchScheduler(const Options& options);

  // Batch scheduling function which runs every scheduling_period_ microseconds.
  void ProcessOneBatch();

  // Notifies scheduler of non-empty batch which is eligible for processing.
  void AddBatch(internal::ASBSBatch<TaskType>*);

  // Removes queue from scheduler.
  void RemoveQueue(const internal::ASBSQueue<TaskType>* queue);

  Env* GetEnv() const { return options_.env; }

  const Options options_;

  struct BatchCompare {
    bool operator()(const internal::ASBSBatch<TaskType>* a,
                    const internal::ASBSBatch<TaskType>* b);
  };

  // Collection of batches added by AddBatch, ordered by age. Owned by scheduler
  // until they are released for processing.
  std::priority_queue<const internal::ASBSBatch<TaskType>*,
                      std::vector<internal::ASBSBatch<TaskType>*>, BatchCompare>
      batches_ GUARDED_BY(mu_);

  // Unowned queues and callbacks added by AddQueue.
  std::unordered_map<const internal::ASBSQueue<TaskType>*, BatchProcessor>
      queues_and_callbacks_ GUARDED_BY(mu_);

  mutex mu_;

  // Responsible for running ProcessOneBatch. PeriodicFunction was used in order
  // to check for deletion so that the thread can be shut down.
  std::unique_ptr<PeriodicFunction> scheduling_thread_;

  // Responsible for running the batch processing callbacks.
  std::unique_ptr<thread::ThreadPool> batch_thread_pool_;

  // Time interval in microseconds between successive ProcessOneBatch calls.
  double scheduling_period_;

  // Exponentially weighted moving average of
  // options_.scheduling_period_feedback() evaluated in each ProcessOneBatch
  // call.
  double ewma_feedback_ = 0;

  TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler);
};
//////////////////////////////////////////////////////////
// Implementation details follow. API users need not read.
namespace internal {
// Consolidates tasks into batches, passing them off to the
// AdaptiveSharedBatchScheduler for processing.
template <typename TaskType>
class ASBSQueue : public BatchScheduler<TaskType> {
 public:
  using QueueOptions =
      typename AdaptiveSharedBatchScheduler<TaskType>::QueueOptions;

  ASBSQueue(std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
            const QueueOptions& options);

  ~ASBSQueue() override;

  // Adds task to current batch. Fails if the task size is larger than the batch
  // size or if the current batch is full and this queue's number of outstanding
  // batches is at its maximum.
  Status Schedule(std::unique_ptr<TaskType>* task) override;

  // Number of tasks waiting to be scheduled.
  size_t NumEnqueuedTasks() const override;

  // Number of size 1 tasks which could currently be scheduled without failing.
  size_t SchedulingCapacity() const override;

  // Notifies queue that a batch is about to be scheduled; the queue should not
  // place any more tasks in this batch.
  void ReleaseBatch(const ASBSBatch<TaskType>* batch);

 private:
  std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_;
  const QueueOptions options_;
  // Owned by scheduler_.
  ASBSBatch<TaskType>* current_batch_ GUARDED_BY(mu_) = nullptr;
  int64 num_enqueued_batches_ GUARDED_BY(mu_) = 0;
  int64 num_enqueued_tasks_ GUARDED_BY(mu_) = 0;
  mutable mutex mu_;
  TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue);
};

// Batch which remembers when and by whom it was created.
template <typename TaskType>
class ASBSBatch : public Batch<TaskType> {
 public:
  ASBSBatch(ASBSQueue<TaskType>* queue, int64 creation_time_micros)
      : queue_(queue), creation_time_micros_(creation_time_micros) {}

  ~ASBSBatch() override {}

  ASBSQueue<TaskType>* queue() const { return queue_; }

  int64 creation_time_micros() const { return creation_time_micros_; }

 private:
  ASBSQueue<TaskType>* queue_;
  const int64 creation_time_micros_;
  TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch);
};
}  // namespace internal

// ---------------- AdaptiveSharedBatchScheduler ----------------

template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::Create(
    const Options& options,
    std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler) {
  if (options.num_batch_threads < 1) {
    return errors::InvalidArgument("num_batch_threads must be positive; was ",
                                   options.num_batch_threads);
  }
  if (options.min_scheduling_period_micros < 0) {
    return errors::InvalidArgument(
        "min_scheduling_period_micros must be >= 0; was ",
        options.min_scheduling_period_micros);
  }
  if (options.min_scheduling_period_micros >
      options.initial_scheduling_period_micros) {
    return errors::InvalidArgument(
        "initial_scheduling_period_micros (",
        options.initial_scheduling_period_micros,
        ") must be >= min_scheduling_period_micros (",
        options.min_scheduling_period_micros, ")");
  }
  if (options.initial_scheduling_period_micros >
      options.max_scheduling_period_micros) {
    return errors::InvalidArgument(
        "initial_scheduling_period_micros (",
        options.initial_scheduling_period_micros,
        ") must be <= max_scheduling_period_micros (",
        options.max_scheduling_period_micros, ")");
  }
  if (options.feedback_smoothing_batches < 1) {
    return errors::InvalidArgument(
        "feedback_smoothing_batches must be positive; was ",
        options.feedback_smoothing_batches);
  }
  scheduler->reset(new AdaptiveSharedBatchScheduler<TaskType>(options));
  return Status::OK();
}

template <typename TaskType>
AdaptiveSharedBatchScheduler<TaskType>::AdaptiveSharedBatchScheduler(
    const Options& options)
    : options_(options),
      scheduling_period_(options.initial_scheduling_period_micros) {
  PeriodicFunction::Options opts;
  opts.thread_name_prefix = "scheduling_thread";
  opts.env = GetEnv();
  scheduling_thread_.reset(
      new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts));
  batch_thread_pool_.reset(new thread::ThreadPool(
      GetEnv(), options.thread_pool_name, options.num_batch_threads));
}

template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::AddQueue(
    const QueueOptions& options, BatchProcessor process_batch_callback,
    std::unique_ptr<BatchScheduler<TaskType>>* queue) {
  if (options.max_batch_size <= 0) {
    return errors::InvalidArgument("max_batch_size must be positive; was ",
                                   options.max_batch_size);
  }
  if (options.max_enqueued_batches <= 0) {
    return errors::InvalidArgument(
        "max_enqueued_batches must be positive; was ",
        options.max_enqueued_batches);
  }
  internal::ASBSQueue<TaskType>* asbs_queue_raw;
  queue->reset(asbs_queue_raw = new internal::ASBSQueue<TaskType>(
                   this->shared_from_this(), options));
  mutex_lock l(mu_);
  queues_and_callbacks_[asbs_queue_raw] = process_batch_callback;
  return Status::OK();
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::AddBatch(
    internal::ASBSBatch<TaskType>* batch) {
  mutex_lock l(mu_);
  batches_.push(batch);
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::RemoveQueue(
    const internal::ASBSQueue<TaskType>* queue) {
  mutex_lock l(mu_);
  queues_and_callbacks_.erase(queue);
}

template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::ProcessOneBatch() {
  static const double kFeedbackMultiplier = .001;
  internal::ASBSBatch<TaskType>* batch = nullptr;
  BatchProcessor callback;
  const int64 start_time_micros = GetEnv()->NowMicros();
  {
    mutex_lock l(mu_);
    if (!batches_.empty()) {
      batch = batches_.top();
      batches_.pop();
      callback = queues_and_callbacks_[batch->queue()];
    }
  }
  if (batch != nullptr) {
    double feedback = options_.scheduling_period_feedback();
    const int64 N = options_.feedback_smoothing_batches;
    ewma_feedback_ = ((N - 1) * ewma_feedback_ + feedback) / N;
    scheduling_period_ *= (1 + kFeedbackMultiplier * ewma_feedback_);
    if (scheduling_period_ < options_.min_scheduling_period_micros) {
      scheduling_period_ = options_.min_scheduling_period_micros;
    } else if (scheduling_period_ > options_.max_scheduling_period_micros) {
      scheduling_period_ = options_.max_scheduling_period_micros;
    }
    // Queue may destroy itself after ReleaseBatch is called.
    batch->queue()->ReleaseBatch(batch);
    batch_thread_pool_->Schedule([callback, batch] {
      callback(std::unique_ptr<Batch<TaskType>>(batch));
    });
  }
  const int64 sleep_time =
      scheduling_period_ - (GetEnv()->NowMicros() - start_time_micros);
  if (sleep_time > 0) {
    GetEnv()->SleepForMicroseconds(sleep_time);
  }
}

template <typename TaskType>
bool AdaptiveSharedBatchScheduler<TaskType>::BatchCompare::operator()(
    const internal::ASBSBatch<TaskType>* a,
    const internal::ASBSBatch<TaskType>* b) {
  return a->creation_time_micros() > b->creation_time_micros();
}

// ---------------- ASBSQueue ----------------

namespace internal {
template <typename TaskType>
ASBSQueue<TaskType>::ASBSQueue(
    std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
    const QueueOptions& options)
    : scheduler_(scheduler), options_(options) {}

template <typename TaskType>
ASBSQueue<TaskType>::~ASBSQueue() {
  // Wait until last batch has been scheduled.
  const int kSleepMicros = 1000;
  for (;;) {
    {
      mutex_lock l(mu_);
      if (num_enqueued_batches_ == 0) {
        break;
      }
    }
    scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros);
  }
  scheduler_->RemoveQueue(this);
}

template <typename TaskType>
Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
  bool added_new_batch = false;
  size_t size = (*task)->size();
  if (size > options_.max_batch_size) {
    return errors::InvalidArgument("Task size ", size,
                                   " is larger than maximum batch size ",
                                   options_.max_batch_size);
  }
  {
    mutex_lock l(mu_);
    // Current batch is full, create another if allowed.
    if (current_batch_ &&
        current_batch_->size() + size > options_.max_batch_size) {
      if (num_enqueued_batches_ >= options_.max_enqueued_batches) {
        return errors::Unavailable("The batch scheduling queue is full");
      }
      current_batch_->Close();
      current_batch_ = nullptr;
    }
    if (!current_batch_) {
      added_new_batch = true;
      num_enqueued_batches_++;
      current_batch_ =
          new ASBSBatch<TaskType>(this, scheduler_->GetEnv()->NowMicros());
    }
    current_batch_->AddTask(std::move(*task));
    num_enqueued_tasks_++;
  }
  if (added_new_batch) scheduler_->AddBatch(current_batch_);
  return Status::OK();
}

template <typename TaskType>
void ASBSQueue<TaskType>::ReleaseBatch(const ASBSBatch<TaskType>* batch) {
  mutex_lock l(mu_);
  num_enqueued_batches_--;
  num_enqueued_tasks_ -= batch->num_tasks();
  if (batch == current_batch_) {
    current_batch_->Close();
    current_batch_ = nullptr;
  }
}

template <typename TaskType>
size_t ASBSQueue<TaskType>::NumEnqueuedTasks() const {
  mutex_lock l(mu_);
  return num_enqueued_tasks_;
}

template <typename TaskType>
size_t ASBSQueue<TaskType>::SchedulingCapacity() const {
  mutex_lock l(mu_);
  const int current_batch_capacity =
      current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
  const int spare_batches =
      options_.max_enqueued_batches - num_enqueued_batches_;
  return spare_batches * options_.max_batch_size + current_batch_capacity;
}
}  // namespace internal
}  // namespace serving
}  // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc
0 → 100644
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h"
#include "tensorflow/contrib/batching/test_util/fake_clock_env.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/test.h"
namespace
tensorflow
{
namespace
serving
{
namespace
anonymous
{
class
FakeTask
:
public
BatchTask
{
public:
explicit
FakeTask
(
size_t
size
)
:
size_
(
size
)
{}
~
FakeTask
()
override
=
default
;
size_t
size
()
const
override
{
return
size_
;
}
private:
const
size_t
size_
;
TF_DISALLOW_COPY_AND_ASSIGN
(
FakeTask
);
};
// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on
// that task. Returns the resulting status.
Status
ScheduleTask
(
size_t
task_size
,
BatchScheduler
<
FakeTask
>*
scheduler
)
{
std
::
unique_ptr
<
FakeTask
>
task
(
new
FakeTask
(
task_size
));
Status
status
=
scheduler
->
Schedule
(
&
task
);
// Schedule() should have consumed 'task' iff it returned Status::OK.
CHECK_EQ
(
status
.
ok
(),
task
==
nullptr
);
return
status
;
}
// Creates a thread that waits on 'start' and then advances the fake clock in
// 'env' in a loop until 'stop' is notified. Useful for allowing objects that
// use the clock to be destroyed.
std
::
unique_ptr
<
Thread
>
CreateFakeClockAdvancerThread
(
test_util
::
FakeClockEnv
*
env
,
Notification
*
start
,
Notification
*
stop
)
{
return
std
::
unique_ptr
<
Thread
>
(
Env
::
Default
()
->
StartThread
(
{},
"FakeClockAdvancerThread"
,
[
env
,
start
,
stop
]
{
start
->
WaitForNotification
();
while
(
!
stop
->
HasBeenNotified
())
{
env
->
AdvanceByMicroseconds
(
10
);
Env
::
Default
()
->
SleepForMicroseconds
(
10
);
}
}));
}
TEST
(
AdaptiveSharedBatchSchedulerTest
,
Basic
)
{
for
(
const
bool
delete_scheduler_early
:
{
false
,
true
})
{
for
(
const
bool
delete_queue_1_early
:
{
false
,
true
})
{
int
queue_0_tasks
=
0
;
auto
queue_0_callback
=
[
&
queue_0_tasks
](
std
::
unique_ptr
<
Batch
<
FakeTask
>>
batch
)
{
ASSERT_TRUE
(
batch
->
IsClosed
());
EXPECT_GT
(
batch
->
num_tasks
(),
0
);
for
(
int
i
=
0
;
i
<
batch
->
num_tasks
();
i
++
)
{
queue_0_tasks
+=
batch
->
task
(
i
).
size
();
}
};
int
queue_1_tasks
=
0
;
auto
queue_1_callback
=
[
&
queue_1_tasks
](
std
::
unique_ptr
<
Batch
<
FakeTask
>>
batch
)
{
ASSERT_TRUE
(
batch
->
IsClosed
());
EXPECT_GT
(
batch
->
num_tasks
(),
0
);
for
(
int
i
=
0
;
i
<
batch
->
num_tasks
();
i
++
)
{
queue_1_tasks
+=
batch
->
task
(
i
).
size
();
}
};
{
std
::
shared_ptr
<
AdaptiveSharedBatchScheduler
<
FakeTask
>>
scheduler
;
TF_ASSERT_OK
(
AdaptiveSharedBatchScheduler
<
FakeTask
>::
Create
({},
&
scheduler
));
// Create two queues.
std
::
unique_ptr
<
BatchScheduler
<
FakeTask
>>
queue_0
;
TF_ASSERT_OK
(
scheduler
->
AddQueue
({},
queue_0_callback
,
&
queue_0
));
std
::
unique_ptr
<
BatchScheduler
<
FakeTask
>>
queue_1
;
TF_ASSERT_OK
(
scheduler
->
AddQueue
({},
queue_1_callback
,
&
queue_1
));
if
(
delete_scheduler_early
)
{
// Delete our copy of the scheduler. The queues should keep it alive
// under the covers.
scheduler
=
nullptr
;
}
// Submit tasks to the two queues, and (optionally) remove the queues.
TF_ASSERT_OK
(
ScheduleTask
(
1
,
queue_0
.
get
()));
TF_ASSERT_OK
(
ScheduleTask
(
2
,
queue_1
.
get
()));
TF_ASSERT_OK
(
ScheduleTask
(
3
,
queue_0
.
get
()));
TF_ASSERT_OK
(
ScheduleTask
(
4
,
queue_1
.
get
()));
if
(
delete_queue_1_early
)
{
queue_1
=
nullptr
;
}
TF_ASSERT_OK
(
ScheduleTask
(
5
,
queue_0
.
get
()));
}
EXPECT_EQ
(
queue_0_tasks
,
9
);
EXPECT_EQ
(
queue_1_tasks
,
6
);
}
}
}
TEST
(
AdaptiveSharedBatchSchedulerTest
,
BadOptions
)
{
using
Scheduler
=
AdaptiveSharedBatchScheduler
<
FakeTask
>
;
std
::
shared_ptr
<
Scheduler
>
scheduler
;
Scheduler
::
Options
options
;
options
.
num_batch_threads
=
0
;
EXPECT_FALSE
(
Scheduler
::
Create
(
options
,
&
scheduler
).
ok
());
options
=
Scheduler
::
Options
();
options
.
min_scheduling_period_micros
=
50
;
options
.
max_scheduling_period_micros
=
100
;
options
.
initial_scheduling_period_micros
=
1
;
EXPECT_FALSE
(
Scheduler
::
Create
(
options
,
&
scheduler
).
ok
());
options
=
Scheduler
::
Options
();
options
.
min_scheduling_period_micros
=
50
;
options
.
max_scheduling_period_micros
=
100
;
options
.
initial_scheduling_period_micros
=
1000
;
EXPECT_FALSE
(
Scheduler
::
Create
(
options
,
&
scheduler
).
ok
());
options
=
Scheduler
::
Options
();
options
.
min_scheduling_period_micros
=
100
;
options
.
max_scheduling_period_micros
=
50
;
options
.
initial_scheduling_period_micros
=
75
;
EXPECT_FALSE
(
Scheduler
::
Create
(
options
,
&
scheduler
).
ok
());
options
=
Scheduler
::
Options
();
options
.
feedback_smoothing_batches
=
0
;
EXPECT_FALSE
(
Scheduler
::
Create
(
options
,
&
scheduler
).
ok
());
}
TEST
(
AdaptiveSharedBatchSchedulerTest
,
ObeysQueueOptions
)
{
test_util
::
FakeClockEnv
env
(
Env
::
Default
());
Notification
start_teardown
,
stop_teardown
;
std
::
unique_ptr
<
Thread
>
teardown_thread
=
CreateFakeClockAdvancerThread
(
&
env
,
&
start_teardown
,
&
stop_teardown
);
{
AdaptiveSharedBatchScheduler
<
FakeTask
>::
Options
options
;
options
.
initial_scheduling_period_micros
=
1000
;
options
.
env
=
&
env
;
std
::
shared_ptr
<
AdaptiveSharedBatchScheduler
<
FakeTask
>>
scheduler
;
TF_ASSERT_OK
(
AdaptiveSharedBatchScheduler
<
FakeTask
>::
Create
(
options
,
&
scheduler
));
std
::
unique_ptr
<
BatchScheduler
<
FakeTask
>>
queue_0
;
std
::
unique_ptr
<
BatchScheduler
<
FakeTask
>>
queue_1
;
int
queue_0_tasks
=
0
;
int
queue_1_tasks
=
0
;
auto
queue_0_callback
=
[
&
queue_0_tasks
,
&
env
](
std
::
unique_ptr
<
Batch
<
FakeTask
>>
batch
)
{
ASSERT_TRUE
(
batch
->
IsClosed
());
EXPECT_GT
(
batch
->
num_tasks
(),
0
);
for
(
int
i
=
0
;
i
<
batch
->
num_tasks
();
i
++
)
{
queue_0_tasks
+=
batch
->
task
(
i
).
size
();
}
env
.
SleepForMicroseconds
(
1
);
};
auto
queue_1_callback
=
[
&
queue_1_tasks
,
&
env
](
std
::
unique_ptr
<
Batch
<
FakeTask
>>
batch
)
{
ASSERT_TRUE
(
batch
->
IsClosed
());
EXPECT_GT
(
batch
->
num_tasks
(),
0
);
for
(
int
i
=
0
;
i
<
batch
->
num_tasks
();
i
++
)
{
queue_1_tasks
+=
batch
->
task
(
i
).
size
();
}
env
.
SleepForMicroseconds
(
1
);
};
AdaptiveSharedBatchScheduler
<
FakeTask
>::
QueueOptions
queue_options
;
queue_options
.
max_batch_size
=
10
;
queue_options
.
max_enqueued_batches
=
0
;
// Queue must have max_enqueued_batchs > 1.
EXPECT_FALSE
(
scheduler
->
AddQueue
(
queue_options
,
queue_0_callback
,
&
queue_0
).
ok
());
queue_options
.
max_enqueued_batches
=
2
;
TF_ASSERT_OK
(
scheduler
->
AddQueue
(
queue_options
,
queue_0_callback
,
&
queue_0
));
queue_options
.
max_batch_size
=
0
;
// Queue must have max_batch_size > 0.
EXPECT_FALSE
(
scheduler
->
AddQueue
(
queue_options
,
queue_1_callback
,
&
queue_1
).
ok
());
queue_options
.
max_batch_size
=
2
;
queue_options
.
max_enqueued_batches
=
1
;
TF_ASSERT_OK
(
scheduler
->
AddQueue
(
queue_options
,
queue_1_callback
,
&
queue_1
));
// Wait for scheduling_thread to sleep.
env
.
BlockUntilThreadsAsleep
(
1
);
// Task larger than max_batch_size shouldn't schedule.
EXPECT_FALSE
(
ScheduleTask
(
15
,
queue_0
.
get
()).
ok
());
TF_ASSERT_OK
(
ScheduleTask
(
5
,
queue_0
.
get
()));
TF_ASSERT_OK
(
ScheduleTask
(
5
,
queue_0
.
get
()));
env
.
AdvanceByMicroseconds
(
1
);
// Task larger than max_batch_size shouldn't schedule.
EXPECT_FALSE
(
ScheduleTask
(
3
,
queue_1
.
get
()).
ok
());
TF_ASSERT_OK
(
ScheduleTask
(
1
,
queue_1
.
get
()));
TF_ASSERT_OK
(
ScheduleTask
(
1
,
queue_1
.
get
()));
env
.
AdvanceByMicroseconds
(
1
);
// Exceeds max_enqueued_batches, shouldn't schedule.
EXPECT_FALSE
(
ScheduleTask
(
1
,
queue_1
.
get
()).
ok
());
TF_ASSERT_OK
(
ScheduleTask
(
5
,
queue_0
.
get
()));
// Exceeds max_enqueued_batches, shouldn't schedule.
EXPECT_FALSE
(
ScheduleTask
(
6
,
queue_0
.
get
()).
ok
());
TF_ASSERT_OK
(
ScheduleTask
(
4
,
queue_0
.
get
()));
// Batches should be processed in order from oldest to newest.
env
.
AdvanceByMicroseconds
(
1000
);
env
.
BlockUntilThreadsAsleep
(
2
);
EXPECT_EQ
(
queue_0_tasks
,
10
);
EXPECT_EQ
(
queue_1_tasks
,
0
);
env
.
AdvanceByMicroseconds
(
1000
);
env
.
BlockUntilThreadsAsleep
(
2
);
EXPECT_EQ
(
queue_0_tasks
,
10
);
EXPECT_EQ
(
queue_1_tasks
,
2
);
env
.
AdvanceByMicroseconds
(
1000
);
env
.
BlockUntilThreadsAsleep
(
2
);
EXPECT_EQ
(
queue_0_tasks
,
19
);
EXPECT_EQ
(
queue_1_tasks
,
2
);
start_teardown
.
Notify
();
}
stop_teardown
.
Notify
();
}
TEST
(
AdaptiveSharedBatchSchedulerTest
,
RateFeedback
)
{
test_util
::
FakeClockEnv
env
(
Env
::
Default
());
Notification
start_teardown
,
stop_teardown
;
std
::
unique_ptr
<
Thread
>
teardown_thread
=
CreateFakeClockAdvancerThread
(
&
env
,
&
start_teardown
,
&
stop_teardown
);
{
double
feedback
=
0
;
AdaptiveSharedBatchScheduler
<
FakeTask
>::
Options
options
;
options
.
initial_scheduling_period_micros
=
1000
;
options
.
min_scheduling_period_micros
=
200
;
options
.
max_scheduling_period_micros
=
2000
;
options
.
env
=
&
env
;
options
.
scheduling_period_feedback
=
[
&
feedback
]
{
return
feedback
;
};
options
.
feedback_smoothing_batches
=
1
;
std
::
shared_ptr
<
AdaptiveSharedBatchScheduler
<
FakeTask
>>
scheduler
;
TF_ASSERT_OK
(
AdaptiveSharedBatchScheduler
<
FakeTask
>::
Create
(
options
,
&
scheduler
));
std
::
unique_ptr
<
BatchScheduler
<
FakeTask
>>
queue
;
int
scheduled_items
=
0
;
auto
queue_callback
=
[
&
scheduled_items
,
&
env
](
std
::
unique_ptr
<
Batch
<
FakeTask
>>
batch
)
{
ASSERT_TRUE
(
batch
->
IsClosed
());
EXPECT_GT
(
batch
->
num_tasks
(),
0
);
scheduled_items
=
0
;
for
(
int
i
=
0
;
i
<
batch
->
num_tasks
();
i
++
)
{
scheduled_items
+=
batch
->
task
(
i
).
size
();
}
env
.
SleepForMicroseconds
(
1
);
};
TF_ASSERT_OK
(
scheduler
->
AddQueue
({},
queue_callback
,
&
queue
));
// Wait for scheduling_thread to sleep.
env
.
BlockUntilThreadsAsleep
(
1
);
// Enqueue 6 batches.
for
(
int
i
=
0
;
i
<
6
;
i
++
)
{
TF_ASSERT_OK
(
ScheduleTask
(
900
+
i
,
queue
.
get
()));
env
.
AdvanceByMicroseconds
(
1
);
}
feedback
=
-
500
;
env
.
AdvanceByMicroseconds
(
994
);
env
.
BlockUntilThreadsAsleep
(
2
);
// scheduling period = 500 usec.
EXPECT_EQ
(
scheduled_items
,
900
);
env
.
AdvanceByMicroseconds
(
500
);
env
.
BlockUntilThreadsAsleep
(
2
);
// scheduling period = 250 usec.
EXPECT_EQ
(
scheduled_items
,
901
);
feedback
=
0
;
env
.
AdvanceByMicroseconds
(
250
);
env
.
BlockUntilThreadsAsleep
(
2
);
// scheduling period = 250 usec.
EXPECT_EQ
(
scheduled_items
,
902
);
feedback
=
10000
;
// large feedback should hit max_scheduling_period.
env
.
AdvanceByMicroseconds
(
250
);
env
.
BlockUntilThreadsAsleep
(
2
);
// scheduling period = 2000 usec.
EXPECT_EQ
(
scheduled_items
,
903
);
feedback
=
-
10000
;
// large feedback should hit min_scheduling_period.
env
.
AdvanceByMicroseconds
(
1999
);
// No callback scheduled, only scheduling thread sleeping.
env
.
BlockUntilThreadsAsleep
(
1
);
EXPECT_EQ
(
scheduled_items
,
903
);
env
.
AdvanceByMicroseconds
(
1
);
env
.
BlockUntilThreadsAsleep
(
2
);
// scheduling period = 200 usec.
EXPECT_EQ
(
scheduled_items
,
904
);
env
.
AdvanceByMicroseconds
(
200
);
env
.
BlockUntilThreadsAsleep
(
2
);
EXPECT_EQ
(
scheduled_items
,
905
);
start_teardown
.
Notify
();
}
stop_teardown
.
Notify
();
}
TEST
(
AdaptiveSharedBatchSchedulerTest
,
FeedbackSmoothing
)
{
test_util
::
FakeClockEnv
env
(
Env
::
Default
());
Notification
start_teardown
,
stop_teardown
;
std
::
unique_ptr
<
Thread
>
teardown_thread
=
CreateFakeClockAdvancerThread
(
&
env
,
&
start_teardown
,
&
stop_teardown
);
{
double
feedback
=
0
;
AdaptiveSharedBatchScheduler
<
FakeTask
>::
Options
options
;
options
.
initial_scheduling_period_micros
=
1000
;
options
.
env
=
&
env
;
options
.
scheduling_period_feedback
=
[
&
feedback
]
{
return
feedback
;
};
options
.
feedback_smoothing_batches
=
3
;
std
::
shared_ptr
<
AdaptiveSharedBatchScheduler
<
FakeTask
>>
scheduler
;
TF_ASSERT_OK
(
AdaptiveSharedBatchScheduler
<
FakeTask
>::
Create
(
options
,
&
scheduler
));
std
::
unique_ptr
<
BatchScheduler
<
FakeTask
>>
queue
;
int
scheduled_items
=
0
;
auto
queue_callback
=
[
&
scheduled_items
,
&
env
](
std
::
unique_ptr
<
Batch
<
FakeTask
>>
batch
)
{
ASSERT_TRUE
(
batch
->
IsClosed
());
EXPECT_GT
(
batch
->
num_tasks
(),
0
);
scheduled_items
=
0
;
for
(
int
i
=
0
;
i
<
batch
->
num_tasks
();
i
++
)
{
scheduled_items
+=
batch
->
task
(
i
).
size
();
}
env
.
SleepForMicroseconds
(
1
);
};
TF_ASSERT_OK
(
scheduler
->
AddQueue
({},
queue_callback
,
&
queue
));
// Wait for scheduling_thread to sleep.
env
.
BlockUntilThreadsAsleep
(
1
);
// Enqueue 4 batches.
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
TF_ASSERT_OK
(
ScheduleTask
(
900
+
i
,
queue
.
get
()));
env
.
AdvanceByMicroseconds
(
1
);
}
feedback
=
-
300
;
env
.
AdvanceByMicroseconds
(
996
);
env
.
BlockUntilThreadsAsleep
(
2
);
// ewma_feedback = 100, scheduling_period = 900.
EXPECT_EQ
(
scheduled_items
,
900
);
env
.
AdvanceByMicroseconds
(
899
);
// No callback scheduled, only scheduling thread sleeping.
env
.
BlockUntilThreadsAsleep
(
1
);
EXPECT_EQ
(
scheduled_items
,
900
);
env
.
AdvanceByMicroseconds
(
1
);
env
.
BlockUntilThreadsAsleep
(
2
);
// ewma_feedback = 167, scheduling_period = 750.
EXPECT_EQ
(
scheduled_items
,
901
);
env
.
AdvanceByMicroseconds
(
749
);
// No callback scheduled, only scheduling thread sleeping.
env
.
BlockUntilThreadsAsleep
(
1
);
EXPECT_EQ
(
scheduled_items
,
901
);
feedback
=
1000
/
3.
;
env
.
AdvanceByMicroseconds
(
1
);
env
.
BlockUntilThreadsAsleep
(
2
);
// emwa_feedback = 0, scheduling_period = 750.
EXPECT_EQ
(
scheduled_items
,
902
);
env
.
AdvanceByMicroseconds
(
749
);
// No callback scheduled, only scheduling thread sleeping.
env
.
BlockUntilThreadsAsleep
(
1
);
EXPECT_EQ
(
scheduled_items
,
902
);
env
.
AdvanceByMicroseconds
(
1
);
env
.
BlockUntilThreadsAsleep
(
2
);
EXPECT_EQ
(
scheduled_items
,
903
);
start_teardown
.
Notify
();
}
stop_teardown
.
Notify
();
}
TEST(AdaptiveSharedBatchSchedulerTest, QueueCapacityInfo) {
  test_util::FakeClockEnv env(Env::Default());
  Notification start_teardown, stop_teardown;
  std::unique_ptr<Thread> teardown_thread =
      CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
  {
    AdaptiveSharedBatchScheduler<FakeTask>::Options options;
    options.initial_scheduling_period_micros = 1000;
    options.env = &env;
    std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
    TF_ASSERT_OK(
        AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
    std::unique_ptr<BatchScheduler<FakeTask>> queue;
    int scheduled_items = 0;
    auto queue_callback = [&scheduled_items,
                           &env](std::unique_ptr<Batch<FakeTask>> batch) {
      ASSERT_TRUE(batch->IsClosed());
      EXPECT_GT(batch->num_tasks(), 0);
      scheduled_items = 0;
      for (int i = 0; i < batch->num_tasks(); i++) {
        scheduled_items += batch->task(i).size();
      }
      env.SleepForMicroseconds(1);
    };
    AdaptiveSharedBatchScheduler<FakeTask>::QueueOptions queue_options;
    queue_options.max_batch_size = 10;
    queue_options.max_enqueued_batches = 10;
    TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue));
    // Wait for scheduling_thread to sleep.
    env.BlockUntilThreadsAsleep(1);
    // Enqueue 3 tasks.
    EXPECT_EQ(queue->NumEnqueuedTasks(), 0);
    EXPECT_EQ(queue->SchedulingCapacity(), 100);
    TF_ASSERT_OK(ScheduleTask(5, queue.get()));
    EXPECT_EQ(queue->NumEnqueuedTasks(), 1);
    EXPECT_EQ(queue->SchedulingCapacity(), 95);
    env.AdvanceByMicroseconds(1);
    TF_ASSERT_OK(ScheduleTask(6, queue.get()));
    EXPECT_EQ(queue->NumEnqueuedTasks(), 2);
    EXPECT_EQ(queue->SchedulingCapacity(), 84);
    env.AdvanceByMicroseconds(1);
    TF_ASSERT_OK(ScheduleTask(1, queue.get()));
    EXPECT_EQ(queue->NumEnqueuedTasks(), 3);
    EXPECT_EQ(queue->SchedulingCapacity(), 83);
    env.AdvanceByMicroseconds(998);
    env.BlockUntilThreadsAsleep(2);
    EXPECT_EQ(scheduled_items, 5);
    env.AdvanceByMicroseconds(1000);
    env.BlockUntilThreadsAsleep(2);
    EXPECT_EQ(scheduled_items, 7);
    start_teardown.Notify();
  }
  stop_teardown.Notify();
}

}  // namespace anonymous
}  // namespace serving
}  // namespace tensorflow
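The capacity values that `QueueCapacityInfo` expects (100, 95, 84, 83) can be reproduced with a small model of the queue: tasks pack greedily into batches of at most `max_batch_size` items, only the most recently opened batch still accepts tasks, and capacity is the free space in that open batch plus the full size of every batch that may still be created. This is a plain-Python sketch of the arithmetic, not the scheduler's actual implementation:

```python
def scheduling_capacity(task_sizes, max_batch_size=10, max_enqueued_batches=10):
    # Pack tasks greedily; a task that does not fit opens a new batch
    # (which closes the previous one).
    batches = []
    for size in task_sizes:
        if batches and batches[-1] + size <= max_batch_size:
            batches[-1] += size
        else:
            batches.append(size)
    open_space = max_batch_size - batches[-1] if batches else max_batch_size
    # Batches that could still be created, not counting the open one.
    creatable = max_enqueued_batches - max(len(batches), 1)
    return creatable * max_batch_size + open_space
```

Feeding it the task sizes from the test (5, then 6, then 1) reproduces each `SchedulingCapacity()` expectation above.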
tensorflow/contrib/batching/batch_scheduler.h (view file @ b3d5ec90)

@@ -78,7 +78,7 @@ template <typename TaskType>
 class Batch {
  public:
   Batch() = default;
-  ~Batch();  // Blocks until the batch is closed.
+  virtual ~Batch();  // Blocks until the batch is closed.

   // Appends 'task' to the batch. After calling AddTask(), the newly-added task
   // can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1).
tensorflow/contrib/cmake/external/cub.cmake (view file @ b3d5ec90)

@@ -14,7 +14,7 @@
 # ==============================================================================
 include (ExternalProject)

-set(cub_URL https://github.com/NVlabs/cub/archive/1.7.4.zip)
+set(cub_URL https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip)
 set(cub_HASH SHA256=20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31)
 set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
 set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
tensorflow/contrib/cmake/external/gif.cmake (view file @ b3d5ec90)

@@ -15,7 +15,7 @@
 include (ExternalProject)

 set(gif_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/gif_archive/giflib-5.1.4/)
-set(gif_URL http://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz)
+set(gif_URL https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz)
 set(gif_HASH SHA256=34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1)
 set(gif_INSTALL ${CMAKE_BINARY_DIR}/gif/install)
 set(gif_BUILD ${CMAKE_BINARY_DIR}/gif/src/gif)
tensorflow/contrib/cmake/external/jpeg.cmake (view file @ b3d5ec90)

@@ -15,7 +15,7 @@
 include (ExternalProject)

 set(jpeg_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/jpeg_archive)
-set(jpeg_URL http://mirror.bazel.build/www.ijg.org/files/jpegsrc.v9a.tar.gz)
+set(jpeg_URL https://mirror.bazel.build/www.ijg.org/files/jpegsrc.v9a.tar.gz)
 set(jpeg_HASH SHA256=3a753ea48d917945dd54a2d97de388aa06ca2eb1066cbfdc6652036349fe05a7)
 set(jpeg_BUILD ${CMAKE_CURRENT_BINARY_DIR}/jpeg/src/jpeg)
 set(jpeg_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/jpeg/install)
tensorflow/contrib/cmake/external/lmdb.cmake (view file @ b3d5ec90)

@@ -15,7 +15,7 @@
 include (ExternalProject)

 set(lmdb_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/lmdb)
-set(lmdb_URL http://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz)
+set(lmdb_URL https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz)
 set(lmdb_HASH SHA256=108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326)
 set(lmdb_BUILD ${CMAKE_BINARY_DIR}/lmdb/src/lmdb)
 set(lmdb_INSTALL ${CMAKE_BINARY_DIR}/lmdb/install)
tensorflow/contrib/cmake/external/snappy.cmake (view file @ b3d5ec90)

@@ -47,4 +47,4 @@ ExternalProject_Add(snappy
 )

 # actually enables snappy in the source code
-add_definitions(-DSNAPPY)
\ No newline at end of file
+add_definitions(-DTF_USE_SNAPPY)
tensorflow/contrib/eager/python/BUILD (view file @ b3d5ec90)

@@ -86,7 +86,7 @@ cuda_py_test(
         "//tensorflow/python:client",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python/eager:graph_callable",
-        "//tensorflow/python:platform_test",
+        "//tensorflow/python/eager:test",
         "//tensorflow/python:variables",
     ],
 )
@@ -132,11 +132,12 @@ py_library(
         "//tensorflow/python:array_ops",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:init_ops",
         "//tensorflow/python:layers_base",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:function",
     ],
 )
@@ -146,6 +147,10 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":metrics",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:variables",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
     ],
 )
@@ -160,6 +165,8 @@ py_library(
     deps = [
         ":datasets",
         ":metrics",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:function",
     ],
 )
tensorflow/contrib/eager/python/evaluator_test.py (view file @ b3d5ec90)

@@ -86,7 +86,7 @@ class EvaluatorTest(test.TestCase):
     for v in e.metric_variables:
       p = v.name.split("/")[0]
       prefix_count[p] = prefix_count.get(p, 0) + 1
-    self.assertEqual({"outer-mean": 2, "mean": 2}, prefix_count)
+    self.assertEqual({"outer_mean": 2, "mean": 2}, prefix_count)

   def testDataset(self):
     e = SimpleEvaluator(IdentityModel())
tensorflow/contrib/eager/python/metrics_impl.py (view file @ b3d5ec90)

@@ -18,6 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import re
+
 from tensorflow.python.eager import context
+from tensorflow.python.eager import function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
@@ -25,55 +29,69 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope

+_to_replace = re.compile("[^A-Za-z0-9.]")
+

 class Metric(object):
   """A metric holds state for aggregating statistics over an evaluation run.

   Users will use Evaluator.add_metric() to add Metric objects to their
-  evaluation, call them in each step, and then use
-  Evaluator.all_metric_results() at the end.
+  evaluation, call them in each step (treating the object as a callable),
+  and then use Evaluator.all_metric_results() at the end.

   Descendants will implement:
-  * call(): Should follow this pattern:
-      if not self.built:
-        self.var = self.add_variable(...)
-      self.add_update(self.var.assign_add(...))
-  * aggregate(): Adds in the state from a list of metrics of the same type
-    as `self`. (Default of summing all the variables will be fine for most
-    descendants.)
-  * result(): Computes and returns a final value for the metric
+  * `build()`: All variables should be created in this method, by calling
+    `self.add_variable()` as in: `self.var = self.add_variable(...)`
+    build() will be called in the first invocation of `__call__()`, with
+    the same arguments passed to `call()`.
+  * `call()`: Has all updates to variables, as in:
+      self.var.assign_add(...)
+  * `result()`: Computes and returns a final value for the metric
     from the variables in `self`.
+
+  Descendants may override, but usually won't need to:
+  * `aggregate()`: Adds in the state from a list of metrics of the same type
+    as `self`. (Default is to sum all the variables.)
+  * `reset()`: Reset all variables to their initial state. (Default is to
+    zero all the variables.)
+  Note that users should not call `aggregate()` or `reset()`, they are for
+  use by TensorFlow infrastructure.
   """

   def __init__(self, name=None):
-    self.built = False
+    self._built = False
     self._vars = []
     self._updates = []
-    self._name = name or self.__class__.__name__
-    # TODO(josh11b): Need some way to make sure two Metrics in the same
-    # Network have distinct names. Maybe we can get a unique name from
-    # a name/variable scope?
-    # TODO(josh11b): self._in_graph_mode = context.in_graph_mode()
+    name = name or self.__class__.__name__
+    # Replace things like spaces in name to create a valid scope name.
+    scope_name = _to_replace.sub("_", name)
+    # We create the variable scope now to get the unique name that will
+    # be used as a variable prefix when build() calls add_variable().
+    with variable_scope.variable_scope(
+        None, default_name=scope_name, use_resource=True,
+        reuse=False) as scope:
+      pos = scope.name.rfind(scope_name)
+      self._name = name + scope.name[pos + len(scope_name):]
+      self._scope = scope
+    if context.in_graph_mode():
+      # We make self.call() into a graph callable here, so that we can
+      # return a single op that performs all of the variable updates.
+      self.call = function.defun(self.call)

   # ---- API for users ----
   def __call__(self, *args, **kwargs):
-    # TODO(josh11b): If self._in_graph_mode is true, make self.call() into a
-    # graph callable here, so that variable updates happen without requiring
-    # a separate fetch.
-    # TODO(josh11b): Do we need a separate build() method to separate
-    # initialization from each update? If so, how do we get the arguments
-    # to it? We *could* just pass in *args and **kwargs...
-    if not self.built:
-      # TODO(ashankar): Set up container isolation so there is no chance
-      # distinct metrics objects accidentally share variables.
-      # TODO(josh11b): Replace things like spaces in self._name to create
-      # a valid scope name.
-      with variable_scope.variable_scope(
-          self._name, use_resource=True, reuse=False):
-        ret = self.call(*args, **kwargs)
-      self.built = True
-    else:
-      ret = self.call(*args, **kwargs)
-    return ret
+    """Returns op to execute to update this metric for these inputs.
+
+    Returns None if eager execution is enabled.
+
+    Args:
+      *args:
+      **kwargs: A mini-batch of inputs to the Metric, passed on to `call()`.
+    """
+    if not self._built:
+      with variable_scope.variable_scope(self._scope):
+        self.build(*args, **kwargs)
+      self._built = True
+    return self.call(*args, **kwargs)

   @property
   def name(self):
@@ -84,10 +102,43 @@ class Metric(object):
     return self._vars

   # ---- To be implemented by descendants ---
+  def build(self, *args, **kwargs):
+    """Method to create variables.
+
+    Called by `__call__()` before `call()` for the first time.
+
+    Args:
+      *args:
+      **kwargs: The arguments to the first invocation of `__call__()`.
+        `build()` may use the shape and/or dtype of these arguments
+        when deciding how to create variables.
+    """
+    raise NotImplementedError("Metrics must define a build() member function")
+
   def call(self, *args, **kwargs):
-    """Accumulates statistics for the metric."""
+    """Accumulates statistics for the metric. Users should use __call__ instead.
+
+    Note: This function is executed as a graph function in graph mode.
+    This means:
+    a) Operations on the same resource are executed in textual order.
+       This should make it easier to do things like add the updated
+       value of a variable to another, for example.
+    b) You don't need to worry about collecting the update ops to execute.
+       All update ops added to the graph by this function will be executed.
+    As a result, code should generally work the same way with graph or
+    eager execution.
+
+    Args:
+      *args:
+      **kwargs: A mini-batch of inputs to the Metric, as passed to
+        `__call__()`.
+    """
     raise NotImplementedError("Metrics must define a call() member function")

+  def result(self):
+    # TODO(josh11b): Add an optional summary_writer parameter.
+    """Computes and returns a final value for the metric."""
+    raise NotImplementedError("Metrics must define a result() member function")
+
   # We can support two different strategies for doing data-parallel
   # distributed metric computations:
   # * Put metric variables on the first device and rely on small
@@ -123,16 +174,19 @@ class Metric(object):
           self._vars[i].assign_add(math_ops.add_n([m._vars[i]
                                                    for m in metrics]))
     # pylint: enable=protected-access

-  def result(self):
-    # TODO(josh11b): Add an optional summary_writer parameter.
-    """Computes and returns a final value for the metric."""
-    raise NotImplementedError("Metrics must define a result() member function")
+  def reset(self):
+    """Reset this metric to a freshly initialized state.
+
+    Default implementation zeros all the metric variables.
+    """
+    for v in self._vars:
+      v.assign(math_ops.zeros_like(v))

   # ---- For use by descendants ---
   def add_variable(self, name, shape=None, dtype=None, initializer=None):
     """***Only for use by descendants of Metric***."""
-    if self.built:
-      raise RuntimeError("Can't call add_variable() after a Metric has been "
-                         "built in the first call().")
+    if self._built:
+      raise RuntimeError("Can't call add_variable() except in build().")
     v = variable_scope.get_variable(name, shape, dtype, initializer,
                                     trainable=False, use_resource=True)
     self._vars.append(v)
@@ -144,6 +198,15 @@ class Mean(Metric):
   # TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64?
   # Or defaults to type of the input if it is tf.float32, else tf.float64?

+  def build(self, values, weights=None):
+    del values, weights  # build() does not use call's arguments
+    self.numer = self.add_variable(name="numer", shape=(),
+                                   dtype=dtypes.float64,
+                                   initializer=init_ops.zeros_initializer)
+    self.denom = self.add_variable(name="denom", shape=(),
+                                   dtype=dtypes.float64,
+                                   initializer=init_ops.zeros_initializer)
+
   def call(self, values, weights=None):
     """Accumulate statistics for computing the mean.
@@ -154,13 +217,6 @@ class Mean(Metric):
       values: Tensor with the per-example value.
       weights: Optional weighting of each example. Defaults to 1.
     """
-    if not self.built:  # False only in the first call().
-      self.numer = self.add_variable(name="numer", shape=(),
-                                     dtype=dtypes.float64,
-                                     initializer=init_ops.zeros_initializer)
-      self.denom = self.add_variable(name="denom", shape=(),
-                                     dtype=dtypes.float64,
-                                     initializer=init_ops.zeros_initializer)
     if weights is None:
       self.denom.assign_add(
           math_ops.cast(array_ops.size(values), dtypes.float64))
@@ -179,6 +235,10 @@ class Mean(Metric):
 class Accuracy(Mean):
   """Calculates how often `predictions` matches `labels`."""

+  def build(self, labels, predictions, weights=None):
+    del labels, predictions, weights
+    super(Accuracy, self).build(None)  # Arguments are unused
+
   def call(self, labels, predictions, weights=None):
     """Accumulate accuracy statistics.
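The build/call/result lifecycle described in the docstrings above can be illustrated without TensorFlow. A framework-free sketch (`TinyMean` is a hypothetical class; plain floats stand in for the resource variables, and `__call__` triggers `build()` exactly once):

```python
class TinyMean:
    """Sketch of the Metric lifecycle: build() creates state on the
    first call, call() accumulates it, result() reduces it."""

    def __init__(self):
        self._built = False

    def build(self):
        # Stand-ins for the numer/denom resource variables.
        self.numer = 0.0
        self.denom = 0.0

    def __call__(self, values, weights=None):
        if not self._built:
            self.build()
            self._built = True
        return self.call(values, weights)

    def call(self, values, weights=None):
        if weights is None:
            self.numer += sum(values)
            self.denom += len(values)
        else:
            self.numer += sum(v * w for v, w in zip(values, weights))
            self.denom += sum(weights)

    def result(self):
        return self.numer / self.denom
```

Feeding the batches [1, 10, 100], [1000], and [10000, 100000] gives a mean of 111111/6, the same arithmetic the Mean metric performs.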
tensorflow/contrib/eager/python/metrics_test.py (view file @ b3d5ec90)

@@ -19,7 +19,11 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.contrib.eager.python import metrics
 from tensorflow.python.eager import context
 from tensorflow.python.eager import test
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import variables


 class MetricsTest(test.TestCase):

@@ -56,6 +60,53 @@ class MetricsTest(test.TestCase):
     m([7], [2])  # 0 correct, weight 1
     self.assertEqual(2.5 / 5, m.result().numpy())

   def testTwoMeans(self):
     # Verify two metrics with the same class and name don't
     # accidentally share state.
     m1 = metrics.Mean()
     m2 = metrics.Mean()
     m1(0)
     m2(2)
     self.assertEqual(0, m1.result().numpy())
     self.assertEqual(2, m2.result().numpy())
     self.assertNotEqual(m1.name, m2.name)

   def testNamesWithSpaces(self):
     # Verify two metrics with the same class and name don't
     # accidentally share state.
     m1 = metrics.Mean("has space")
     m2 = metrics.Mean("has space")
     m2(2)
     m1(0)
     self.assertEqual(m1.name, "has space")
     self.assertEqual(m1.numer.name, "has_space/numer:0")
     self.assertEqual(m2.name, "has space_1")
     self.assertEqual(m2.numer.name, "has_space_1/numer:0")

   def testGraph(self):
     with context.graph_mode(), self.test_session() as sess:
       m = metrics.Mean()
       p = array_ops.placeholder(dtypes.float32)
       accumulate = m(p)
       variables.global_variables_initializer().run()
       sess.run(accumulate, feed_dict={p: [1, 10, 100]})
       sess.run(accumulate, feed_dict={p: 1000})
       sess.run(accumulate, feed_dict={p: [10000, 100000]})
       self.assertAllEqual(m.result().eval(), 111111.0 / 6)

   def testTwoMeansGraph(self):
     # Verify two metrics with the same class and name don't
     # accidentally share state.
     with context.graph_mode(), self.test_session() as sess:
       m1 = metrics.Mean()
       m2 = metrics.Mean()
       accumulate1 = m1(0)
       accumulate2 = m2(2)
       variables.global_variables_initializer().run()
       sess.run([accumulate1, accumulate2])
       self.assertEqual(0, m1.result().eval())
       self.assertEqual(2, m2.result().eval())


 if __name__ == "__main__":
   test.main()
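The `has_space/numer:0` variable names checked in `testNamesWithSpaces` come from sanitizing the metric name before its variable scope is opened, using the `[^A-Za-z0-9.]` pattern from `metrics_impl.py`. A sketch of just that sanitization step (the `_1` suffix on the second metric is added by variable-scope uniquification, not by this function):

```python
import re

_to_replace = re.compile("[^A-Za-z0-9.]")

def scope_name_for(metric_name):
    # Replace characters that are invalid in a scope name with underscores.
    return _to_replace.sub("_", metric_name)
```

The same substitution explains the `outer_mean` prefix asserted in the evaluator test above.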
tensorflow/contrib/eager/python/saver_test.py (view file @ b3d5ec90)

@@ -22,6 +22,7 @@ import os
 from tensorflow.contrib.eager.python import saver as _saver
 from tensorflow.python.eager import context
 from tensorflow.python.eager import graph_callable
+from tensorflow.python.eager import test
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -29,7 +30,6 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.platform import test


 class SaverTest(test.TestCase):

@@ -38,7 +38,7 @@ class SaverTest(test.TestCase):
     return '/device:GPU:0' if context.num_gpus() else '/device:CPU:0'

   def testBasics(self):
-    with context.eager_mode(), ops.device(self._dev()):
+    with ops.device(self._dev()):
       v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
       def model():
         return array_ops.constant(2.0) * v1
@@ -54,8 +54,42 @@
       saver.restore(ckpt_prefix)
       self.assertEqual(v1.read_value().numpy(), 1.0)

-  def testRestoreOnCreate(self):
+  def testSameNameNoClobbering(self):
+    with context.eager_mode(), ops.device(self._dev()):
+      # Note that this test purposefully uses Graphs rather than
+      # IsolateTest. Users are more likely to accidentally create the same
+      # variable name this way.
+      first_graph = ops.Graph()
+      with first_graph.as_default():
+        v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1')
+      with ops.Graph().as_default():
+        v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1')
+        saver = _saver.Saver([v1_first_graph, v1_second_graph])
+        ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
+        with self.assertRaisesRegexp(ValueError, 'v1'):
+          saver.save(ckpt_prefix)
+
+  def testDifferentGraphError(self):
+    with context.eager_mode(), ops.device(self._dev()):
+      with ops.Graph().as_default():
+        v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
+      with ops.Graph().as_default():
+        saver = _saver.Saver([v1])
+        ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
+        with self.assertRaisesRegexp(ValueError, 'Graph'):
+          saver.save(ckpt_prefix)
+
+  def testSameObjectOK(self):
+    with context.eager_mode(), ops.device(self._dev()):
+      v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
+      # While different objects with the same shared_name are not good, passing
+      # in the same object multiple times is fine.
+      saver = _saver.Saver([v1, v1])
+      ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
+      saver.save(ckpt_prefix)
+
+  def testRestoreOnCreate(self):
     with ops.device(self._dev()):
       def model(init_val):
         v1 = resource_variable_ops.ResourceVariable(init_val, name='v1')
         return array_ops.constant(1.0) * v1, v1
@@ -71,12 +105,9 @@
         # Value is from checkpoint, but not from argument.
         ret, _ = model(2.0)
         self.assertEqual(ret.numpy(), 1.0)
-      # Create it a second time won't re-assign the checkpoint value.
-      v1_2 = resource_variable_ops.ResourceVariable(3.0, name='v1')
-      self.assertEqual(v1_2.read_value().numpy(), 3.0)

   def testRestoreNotFound(self):
-    with context.eager_mode(), ops.device(self._dev()):
+    with ops.device(self._dev()):
       def model(v):
         return array_ops.constant(1.0) * v
@@ -92,7 +123,7 @@
       _ = model(resource_variable_ops.ResourceVariable(1.0, name='v2'))

   def testSaveRestoreGraphCallable(self):
-    with context.eager_mode(), ops.device(self._dev()):
+    with ops.device(self._dev()):
       @graph_callable.graph_callable(
           [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
       def model(x):
tensorflow/contrib/eager/python/tfe.py (view file @ b3d5ec90)

@@ -53,6 +53,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
 @@in_eager_mode
 @@in_graph_mode

 @@IsolateTest
+@@run_test_in_graph_and_eager_modes
 """
@@ -84,6 +85,7 @@ from tensorflow.python.eager.execution_callbacks import nan_callback
 from tensorflow.python.eager.execution_callbacks import seterr
 from tensorflow.python.framework.ops import enable_eager_execution
 from tensorflow.python.framework.ops import eager_run as run
 from tensorflow.python.framework.test_util import IsolateTest
+from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes
 from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
 from tensorflow.python.util.all_util import remove_undocumented
tensorflow/contrib/factorization/g3doc/kmeans.md (view file @ b3d5ec90)

@@ -24,7 +24,11 @@ the full-batch version.
 approach for computing the initial cluster assignments that is expensive but is
 typically less prone to getting stuck in bad local minima.

-We provide distributed implementations of both full-batch and mini-batch
-K-Means algorithm. Both K-Means++ and random initialization are supported.
-The user can also choose between **Cosine** and **Squared Euclidean** distance
-metrics.
+**[k-MC2](https://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/view/12147/11759)**
+provides a very fast seeding method that produces high-quality centers
+comparable to K-Means++ seeding. k-MC2 works particularly well when combined
+with mini-batch K-Means.
+
+We provide distributed implementations of both the full-batch and mini-batch
+K-Means algorithms. K-Means++, k-MC2 and random initialization are supported.
+The user can also choose between **Cosine** and **Squared Euclidean** distance
+metrics.
tensorflow/contrib/factorization/kernels/clustering_ops.cc (view file @ b3d5ec90)

@@ -224,6 +224,58 @@ class KmeansPlusPlusInitializationOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization").Device(DEVICE_CPU),
                         KmeansPlusPlusInitializationOp);

 // Implements a single Markov chain for the k-MC^2 algorithm.
 class KMC2ChainInitializationOp : public OpKernel {
  public:
   explicit KMC2ChainInitializationOp(OpKernelConstruction* context)
       : OpKernel(context) {
     OP_REQUIRES_OK(context,
                    context->MatchSignature({DT_FLOAT, DT_INT64}, {DT_INT64}));
   }

   void Compute(OpKernelContext* context) override {
     const Tensor& distances_tensor = context->input(0);
     const Tensor& seed_tensor = context->input(1);
     OP_REQUIRES(context, TensorShapeUtils::IsVector(distances_tensor.shape()),
                 InvalidArgument("Input distances should be a vector."));
     OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
                 InvalidArgument("Input seed should be a scalar."));
     const int64 num_points = distances_tensor.dim_size(0);
     const int64 seed = seed_tensor.scalar<int64>()();
     OP_REQUIRES(context, num_points > 0,
                 InvalidArgument("Expected distances_tensor.size() > 0."));

     random::PhiloxRandom random(seed);
     random::SimplePhilox rng(&random);
     auto distances = distances_tensor.flat<float>();
     // Set the initial state of the Markov chain to be the first candidate.
     int64 selected_index = 0;
     float selected_distance = distances(selected_index);
     // Build a Markov chain of length num_points.
     for (int64 i = 1; i < num_points; ++i) {
       const float candidate_distance = distances(i);
       // Set the next state of the Markov chain to be the candidate with
       // probability min(1, candidate_distance/selected_distance).
       if (candidate_distance > rng.RandFloat() * selected_distance) {
         selected_index = i;
         selected_distance = candidate_distance;
       }
     }

     Tensor* output_sampled_index_tensor;
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, TensorShape({}),
                                             &output_sampled_index_tensor));
     auto output = output_sampled_index_tensor->scalar<int64>();
     // Return the last state of the Markov chain as the new center.
     output() = selected_index;
   }
 };

 REGISTER_KERNEL_BUILDER(Name("KMC2ChainInitialization").Device(DEVICE_CPU),
                         KMC2ChainInitializationOp);

 // Operator for computing the nearest neighbors for a set of points.
 class NearestNeighborsOp : public OpKernel {
  public:
tensorflow/contrib/factorization/kernels/clustering_ops_test.cc (view file @ b3d5ec90)

@@ -116,6 +116,62 @@ RUN_BM_KmeansPlusPlusInitialization(k3RetriesPerSample);
 #undef RUN_BM_KmeansPlusPlusInitialization
 #undef BENCHMARK_KMEANS_PLUS_PLUS

 Graph* SetUpKMC2Initialization(int num_points) {
   Graph* g = new Graph(OpRegistry::Global());
   Tensor distances(DT_FLOAT, TensorShape({num_points}));
   Tensor seed(DT_INT64, TensorShape({}));
   distances.flat<float>().setRandom();
   seed.flat<int64>().setConstant(12345);

   TF_CHECK_OK(
       NodeBuilder("KMC2ChainInitializationOp", "KMC2ChainInitialization")
           .Input(test::graph::Constant(g, distances))
           .Input(test::graph::Constant(g, seed))
           .Finalize(g, nullptr /* node */));
   return g;
 }

 template <int num_points, int num_to_sample, int num_dims>
 void BM_KMC2Initialization(int iters) {
   testing::StopTiming();
   testing::ItemsProcessed(static_cast<int64>(iters) * num_points * num_dims *
                           num_to_sample);
   testing::UseRealTime();
   Graph* g = SetUpKMC2Initialization(num_points);
   testing::StartTiming();
   test::Benchmark("cpu", g).Run(iters);
 }

 #define BENCHMARK_KMC2(p, c, d)                           \
   void BM_KMC2Initialization_##p##_##c##_##d(int iters) { \
     BM_KMC2Initialization<p, c, d>(iters);                \
   }                                                       \
   BENCHMARK(BM_KMC2Initialization_##p##_##c##_##d);

 #define RUN_BM_KMC2Initialization                  \
   BENCHMARK_KMC2(k10Points, k2Centers, k100Dim);   \
   BENCHMARK_KMC2(k10Points, k5Centers, k100Dim);   \
   BENCHMARK_KMC2(k10Points, k10Centers, k100Dim);  \
   BENCHMARK_KMC2(k100Points, k10Centers, k100Dim); \
   BENCHMARK_KMC2(k100Points, k20Centers, k100Dim); \
   BENCHMARK_KMC2(k100Points, k50Centers, k100Dim); \
   BENCHMARK_KMC2(k100Points, k100Centers, k100Dim);\
   BENCHMARK_KMC2(k1kPoints, k100Centers, k100Dim); \
   BENCHMARK_KMC2(k1kPoints, k200Centers, k100Dim); \
   BENCHMARK_KMC2(k1kPoints, k500Centers, k100Dim); \
   BENCHMARK_KMC2(k1kPoints, k1kCenters, k100Dim);  \
   BENCHMARK_KMC2(k10kPoints, k100Centers, k100Dim);\
   BENCHMARK_KMC2(k10kPoints, k200Centers, k100Dim);\
   BENCHMARK_KMC2(k10kPoints, k500Centers, k100Dim);\
   BENCHMARK_KMC2(k10kPoints, k1kCenters, k100Dim); \
   BENCHMARK_KMC2(k1MPoints, k100Centers, k100Dim); \
   BENCHMARK_KMC2(k1MPoints, k200Centers, k100Dim); \
   BENCHMARK_KMC2(k1MPoints, k500Centers, k100Dim); \
   BENCHMARK_KMC2(k1MPoints, k1kCenters, k100Dim)

 RUN_BM_KMC2Initialization;

 #undef RUN_BM_KMC2Initialization
 #undef BENCHMARK_KMC2

 Graph* SetUpNearestNeighbors(int num_dims, int num_points, int num_centers,
                              int k) {
   Graph* g = new Graph(OpRegistry::Global());
tensorflow/contrib/factorization/ops/clustering_ops.cc (view file @ b3d5ec90)

@@ -44,6 +44,25 @@ num_retries_per_sample: Scalar. For each row that is sampled, this parameter
 samples: Matrix of shape (num_to_sample, d). The sampled rows.
 )");

 REGISTER_OP("KMC2ChainInitialization")
     .Input("distances: float32")
     .Input("seed: int64")
     .Output("index: int64")
     .SetShapeFn(shape_inference::ScalarShape)
     .Doc(R"(
 Returns the index of a data point that should be added to the seed set.

 Entries in distances are assumed to be squared distances of candidate points to
 the already sampled centers in the seed set. The op constructs one Markov chain
 of the k-MC^2 algorithm and returns the index of one candidate point to be added
 as an additional cluster center.

 distances: Vector with squared distances to the closest previously sampled
   cluster center for each candidate point.
 seed: Scalar. Seed for initializing the random number generator.
 index: Scalar with the index of the sampled point.
 )");

 REGISTER_OP("NearestNeighbors")
     .Input("points: float32")
     .Input("centers: float32")
tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py (view file @ b3d5ec90)

@@ -55,6 +55,63 @@ class KmeansPlusPlusInitializationTest(test.TestCase):
       self.runTestWithSeed(seed)


 class KMC2InitializationTest(test.TestCase):

   def runTestWithSeed(self, seed):
     with self.test_session():
       distances = np.zeros(1000).astype(np.float32)
       distances[6] = 10e7
       distances[4] = 10e3

       sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed)
       self.assertEquals(sampled_point.eval(), 6)
       distances[6] = 0.0
       sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed)
       self.assertEquals(sampled_point.eval(), 4)

   def testBasic(self):
     for seed in range(100):
       self.runTestWithSeed(seed)


 class KMC2InitializationLargeTest(test.TestCase):

   def setUp(self):
     self._distances = np.zeros(1001)
     self._distances[500] = 100.0
     self._distances[1000] = 50.0

   def testBasic(self):
     with self.test_session():
       counts = {}
       seed = 0
       for i in range(50):
         sample = clustering_ops.kmc2_chain_initialization(
             self._distances, seed + i).eval()
         counts[sample] = counts.get(sample, 0) + 1
       self.assertEquals(len(counts), 2)
       self.assertTrue(500 in counts)
       self.assertTrue(1000 in counts)
       self.assertGreaterEqual(counts[500], 5)
       self.assertGreaterEqual(counts[1000], 5)


 class KMC2InitializationCornercaseTest(test.TestCase):

   def setUp(self):
     self._distances = np.zeros(10)

   def runTestWithSeed(self, seed):
     with self.test_session():
       sampled_point = clustering_ops.kmc2_chain_initialization(
           self._distances, seed)
       self.assertEquals(sampled_point.eval(), 0)

   def testBasic(self):
     for seed in range(100):
       self.runTestWithSeed(seed)


 # A simple test that can be verified by hand.
 class NearestCentersTest(test.TestCase):
tensorflow/contrib/factorization/python/ops/clustering_ops.py
浏览文件 @
b3d5ec90
...
...
@@ -50,6 +50,7 @@ COSINE_DISTANCE = 'cosine'
 RANDOM_INIT = 'random'
 KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'
+KMC2_INIT = 'kmc2'
 
 # The name of the variable holding the cluster centers. Used by the Estimator.
 CLUSTERS_VAR_NAME = 'clusters'
...
...
@@ -66,7 +67,8 @@ class KMeans(object):
                use_mini_batch=False,
                mini_batch_steps_per_iteration=1,
                random_seed=0,
-               kmeans_plus_plus_num_retries=2):
+               kmeans_plus_plus_num_retries=2,
+               kmc2_chain_length=200):
     """Creates an object for generating KMeans clustering graph.
 
     This class implements the following variants of K-means algorithm:
...
...
@@ -95,7 +97,8 @@ class KMeans(object):
     exactly like a full-batch version.
 
     Args:
-      inputs: An input tensor or list of input tensors
+      inputs: An input tensor or list of input tensors. It is assumed that the
+        data points have been previously randomly permuted.
       num_clusters: An integer tensor specifying the number of clusters. This
         argument is ignored if initial_clusters is a tensor or numpy array.
       initial_clusters: Specifies the clusters used during initialization. One
...
...
@@ -104,6 +107,7 @@ class KMeans(object):
- a function f(inputs, k) that returns up to k centers from `inputs`.
- "random": Choose centers randomly from `inputs`.
- "kmeans_plus_plus": Use kmeans++ to choose centers from `inputs`.
- "kmc2": Use the fast k-MC2 algorithm to choose centers from `inputs`.
In the last three cases, one batch of `inputs` may not yield
`num_clusters` centers, in which case initialization will require
multiple batches until enough centers are chosen. In the case of
...
...
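The function form of `initial_clusters` documented above takes `(inputs, k)` and may return up to `k` centers. A trivial NumPy stand-in (a hypothetical helper, not part of the diff) illustrates the contract:

```python
import numpy as np

def first_k_centers(inputs, k):
    """A minimal f(inputs, k): returns up to k centers from `inputs`.

    Per the docstring contract, the callable may return fewer than k rows
    when the batch is too small; initialization then consumes more batches
    until enough centers are chosen.
    """
    return inputs[:k]

batch = np.arange(12, dtype=np.float32).reshape(6, 2)
centers = first_k_centers(batch, 4)      # full k centers available
short = first_k_centers(batch[:2], 4)    # batch smaller than k: 2 centers
```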
@@ -121,13 +125,17 @@ class KMeans(object):
         additional points to draw from the current distribution before selecting
         the best. If a negative value is specified, a heuristic is used to
         sample O(log(num_to_sample)) additional points.
+      kmc2_chain_length: Determines how many candidate points are used by the
+        k-MC2 algorithm to produce one new cluster center. If a (mini-)batch
+        contains fewer points, one new cluster center is generated from the
+        (mini-)batch.
 
     Raises:
       ValueError: An invalid argument was passed to initial_clusters or
         distance_metric.
     """
     if isinstance(initial_clusters, str) and initial_clusters not in [
-        RANDOM_INIT, KMEANS_PLUS_PLUS_INIT
+        RANDOM_INIT, KMEANS_PLUS_PLUS_INIT, KMC2_INIT
     ]:
       raise ValueError(
           "Unsupported initialization algorithm '%s'" % initial_clusters)
...
...
@@ -141,6 +149,7 @@ class KMeans(object):
     self._mini_batch_steps_per_iteration = int(mini_batch_steps_per_iteration)
     self._random_seed = random_seed
     self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
+    self._kmc2_chain_length = kmc2_chain_length
 
   @classmethod
   def _distance_graph(cls, inputs, clusters, distance_metric):
...
...
@@ -302,9 +311,10 @@ class KMeans(object):
else
:
cluster_centers_updated
=
cluster_centers
update_in_steps
=
None
cluster_counts
=
(
variable_scope
.
variable
(
array_ops
.
ones
([
num_clusters
],
dtype
=
dtypes
.
int64
))
if
self
.
_use_mini_batch
else
None
)
cluster_counts
=
(
variable_scope
.
variable
(
array_ops
.
ones
([
num_clusters
],
dtype
=
dtypes
.
int64
))
if
self
.
_use_mini_batch
else
None
)
return
(
cluster_centers
,
cluster_centers_initialized
,
cluster_counts
,
cluster_centers_updated
,
update_in_steps
)
...
...
@@ -359,7 +369,7 @@ class KMeans(object):
init_op
=
_InitializeClustersOpFactory
(
self
.
_inputs
,
num_clusters
,
initial_clusters
,
self
.
_distance_metric
,
self
.
_random_seed
,
self
.
_kmeans_plus_plus_num_retries
,
cluster_centers_var
,
cluster_centers_updated
,
self
.
_kmc2_chain_length
,
cluster_centers_var
,
cluster_centers_updated
,
cluster_centers_initialized
).
op
()
cluster_centers
=
cluster_centers_var
...
...
@@ -520,8 +530,9 @@ class KMeans(object):
                   array_ops.reshape(array_ops.shape(inp)[0], [-1])), [-1, 1]),
               cluster_idx, num_clusters))
     with ops.colocate_with(cluster_centers, ignore_existing=True):
-      new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast(
-          math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
+      new_clusters_centers = math_ops.add_n(cluster_sums) / (
+          math_ops.cast(math_ops.add_n(cluster_counts), cluster_sums[0].dtype) +
+          epsilon)
       if self._clusters_l2_normalized():
         new_clusters_centers = nn_impl.l2_normalize(new_clusters_centers, dim=1)
     return state_ops.assign(cluster_centers, new_clusters_centers)
...
...
@@ -548,9 +559,12 @@ class _InitializeClustersOpFactory(object):
     cluster_centers_initialized := true
   """
 
+  # TODO(ccolby): Refactor this class so that kmc2 isn't so much a special case.
+
   def __init__(self, inputs, num_clusters, initial_clusters, distance_metric,
-               random_seed, kmeans_plus_plus_num_retries, cluster_centers,
-               cluster_centers_updated, cluster_centers_initialized):
+               random_seed, kmeans_plus_plus_num_retries, kmc2_chain_length,
+               cluster_centers, cluster_centers_updated,
+               cluster_centers_initialized):
     """Creates an op factory.
 
     Args:
...
@@ -560,6 +574,7 @@ class _InitializeClustersOpFactory(object):
       distance_metric: See KMeans constructor.
       random_seed: See KMeans constructor.
       kmeans_plus_plus_num_retries: See KMeans constructor.
+      kmc2_chain_length: See KMeans constructor.
       cluster_centers: The TF variable holding the initial centers. It may
         already contain some centers when the op is executed.
       cluster_centers_updated: A second TF variable to hold a copy of the
...
@@ -575,6 +590,7 @@ class _InitializeClustersOpFactory(object):
     self._distance_metric = distance_metric
     self._random_seed = random_seed
     self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
+    self._kmc2_chain_length = kmc2_chain_length
     self._cluster_centers = cluster_centers
     self._cluster_centers_updated = cluster_centers_updated
     self._cluster_centers_initialized = cluster_centers_initialized
...
...
@@ -604,6 +620,90 @@ class _InitializeClustersOpFactory(object):
         math_ops.to_int64(self._num_remaining), self._random_seed,
         self._kmeans_plus_plus_num_retries)
 
+  def _kmc2_multiple_centers(self):
+    """Adds new initial cluster centers using the k-MC2 algorithm.
+
+    In each call to the op, the provided batch is split into subsets based on
+    the specified `kmc2_chain_length`. On each subset, a single Markov chain of
+    the k-MC2 algorithm is used to add *one* new cluster center. If there
+    are fewer than `kmc2_chain_length` points in the subset, a single center is
+    added using one Markov chain on the full input. It is assumed that the
+    provided batch has previously been randomly permuted. Otherwise, k-MC2 may
+    return suboptimal centers.
+
+    Returns:
+      An op that adds new cluster centers.
+    """
+    # The op only operates on the first shard of data.
+    first_shard = self._inputs[0]
+    # Number of points in the input that can be used.
+    batch_size = array_ops.shape(first_shard)[0]
+    # Maximum number of subsets such that the size of each subset is at least
+    # `kmc2_chain_length`. Final subsets may be larger.
+    max_to_sample = math_ops.cast(
+        batch_size / self._kmc2_chain_length, dtype=dtypes.int32)
+    # We sample at least one new center and at most all remaining centers.
+    num_to_sample = math_ops.maximum(
+        math_ops.minimum(self._num_remaining, max_to_sample), 1)
+
+    def _cond(i, _):
+      """Stopping condition for the while loop."""
+      return math_ops.less(i, num_to_sample)
+
+    def _body(i, _):
+      """Body that adds a single new center based on a subset."""
+
+      def _sample_random():
+        """Returns a random point as a cluster center."""
+        # By assumption the batch is reshuffled and _sample_random is always
+        # called for i=0. Hence, we simply return the first point.
+        new_center = array_ops.reshape(first_shard[0], [1, -1])
+        if self._distance_metric == COSINE_DISTANCE:
+          new_center = nn_impl.l2_normalize(new_center, dim=1)
+        return new_center
+
+      def _sample_kmc2_chain():
+        """Returns previous centers as well as a new center sampled using k-MC2.
+        """
+        # Extract the subset from the underlying batch.
+        start = i * self._kmc2_chain_length
+        end = start + self._kmc2_chain_length
+        subset = first_shard[start:end]
+        # Compute the distances from points in the subset to previous centers.
+        _, distances = gen_clustering_ops.nearest_neighbors(
+            subset, self._cluster_centers, 1)
+        # Sample index of new center using k-MC2 Markov chain.
+        new_center_index = gen_clustering_ops.kmc2_chain_initialization(
+            array_ops.squeeze(distances), self._random_seed)
+        # Extract actual new center.
+        newly_sampled_center = array_ops.reshape(subset[new_center_index],
+                                                 [1, -1])
+        # Return concatenation with previously sampled centers.
+        if self._distance_metric == COSINE_DISTANCE:
+          newly_sampled_center = nn_impl.l2_normalize(
+              newly_sampled_center, dim=1)
+        return array_ops.concat([self._cluster_centers, newly_sampled_center],
+                                0)
+
+      # Obtain a random point if there are no previously sampled centers.
+      # Otherwise, construct a k-MC2 Markov chain.
+      new_centers = control_flow_ops.cond(
+          math_ops.equal(self._num_selected, 0), _sample_random,
+          _sample_kmc2_chain)
+      # Assign new cluster centers to underlying variable.
+      assigned_centers = state_ops.assign(
+          self._cluster_centers, new_centers, validate_shape=False)
+      if self._cluster_centers_updated is not self._cluster_centers:
+        assigned_centers = state_ops.assign(
+            self._cluster_centers_updated,
+            assigned_centers,
+            validate_shape=False)
+      return i + 1, self._num_clusters - array_ops.shape(assigned_centers)[0]
+
+    # Add num_to_sample new data points.
+    _, num_remaining = control_flow_ops.while_loop(_cond, _body, [0, 0])
+    return num_remaining
+
   def _greedy_batch_sampler(self, sampler):
     # If the input dataset size is smaller than the number of centers
     # remaining, choose the entire input dataset as centers. This can happen
...
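The subset bookkeeping in `_kmc2_multiple_centers` above is plain integer arithmetic; a small sketch with hypothetical sizes (not TensorFlow code) shows how many Markov chains one batch yields and which slice feeds each:

```python
# Hypothetical sizes, illustrating the subset arithmetic used by
# _kmc2_multiple_centers above.
batch_size = 1050
kmc2_chain_length = 200
num_remaining = 7  # cluster centers still needed

# Maximum number of chain-length subsets; the final subset may be larger.
max_to_sample = batch_size // kmc2_chain_length
# Sample at least one new center and at most all remaining centers.
num_to_sample = max(min(num_remaining, max_to_sample), 1)

# Each while-loop iteration i feeds one Markov chain with this slice
# of the (pre-shuffled) first shard.
subsets = [(i * kmc2_chain_length, (i + 1) * kmc2_chain_length)
           for i in range(num_to_sample)]
```

With these numbers, a 1050-point batch yields 5 chains, so 5 of the 7 remaining centers are added in one call; the next batch supplies the rest.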
@@ -657,7 +757,10 @@ class _InitializeClustersOpFactory(object):
     with ops.control_dependencies([
         check_ops.assert_positive(self._num_remaining),
     ]):
-      num_now_remaining = self._add_new_centers()
+      if self._initial_clusters == KMC2_INIT:
+        num_now_remaining = self._kmc2_multiple_centers()
+      else:
+        num_now_remaining = self._add_new_centers()
       return control_flow_ops.cond(
           math_ops.equal(num_now_remaining, 0),
           lambda: state_ops.assign(self._cluster_centers_initialized, True),
...
tensorflow/contrib/framework/__init__.py
...
...
@@ -37,6 +37,7 @@ See the @{$python/contrib.framework} guide.
 @@arg_scope
 @@add_arg_scope
+@@current_arg_scope
 @@has_arg_scope
 @@arg_scoped_arguments
...
...
The following files were also changed in this merge (diffs collapsed in the page view):

tensorflow/contrib/framework/python/ops/arg_scope.py
tensorflow/contrib/learn/python/learn/learn_io/graph_io.py
tensorflow/contrib/makefile/Makefile
tensorflow/contrib/makefile/download_dependencies.sh
tensorflow/contrib/metrics/python/ops/metric_ops.py
tensorflow/contrib/metrics/python/ops/metric_ops_test.py
tensorflow/contrib/quantize/BUILD
tensorflow/contrib/quantize/python/copy_graph_test.py
tensorflow/contrib/quantize/python/fold_batch_norms.py
tensorflow/contrib/quantize/python/fold_batch_norms_test.py
tensorflow/contrib/quantize/python/graph_matcher.py (new file, mode 100644)
tensorflow/contrib/quantize/python/graph_matcher_test.py (new file, mode 100644)
tensorflow/contrib/quantize/python/quantize_parameterized_test.py
tensorflow/contrib/rnn/BUILD
tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc
tensorflow/contrib/seq2seq/ops/beam_search_ops.cc
tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py
tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
tensorflow/contrib/training/python/training/bucket_ops.py
tensorflow/core/BUILD
tensorflow/core/framework/api_def.proto
tensorflow/core/framework/op_gen_lib.cc
tensorflow/core/framework/op_gen_lib_test.cc
tensorflow/core/graph/graph_constructor.cc
tensorflow/core/grappler/grappler_item_builder.cc
tensorflow/core/kernels/mkl_transpose_op.cc
tensorflow/core/kernels/transpose_functor.h
tensorflow/core/platform/default/build_config.bzl
tensorflow/core/platform/posix/port.cc
tensorflow/contrib/s3/BUILD → tensorflow/core/platform/s3/BUILD (file moved)
tensorflow/contrib/s3/s3_crypto.cc → tensorflow/core/platform/s3/s3_crypto.cc (file moved)
tensorflow/contrib/s3/s3_crypto.h → tensorflow/core/platform/s3/s3_crypto.h (file moved)
tensorflow/contrib/s3/s3_file_system.cc → tensorflow/core/platform/s3/s3_file_system.cc (file moved)
tensorflow/contrib/s3/s3_file_system.h → tensorflow/core/platform/s3/s3_file_system.h (file moved)
tensorflow/contrib/s3/s3_file_system_test.cc → tensorflow/core/platform/s3/s3_file_system_test.cc (file moved)
tensorflow/core/platform/windows/port.cc
tensorflow/examples/learn/iris.py
tensorflow/examples/learn/random_forest_mnist.py
tensorflow/examples/learn/text_classification_character_rnn.py
tensorflow/python/BUILD
tensorflow/python/eager/backprop.py
tensorflow/python/eager/backprop_test.py
tensorflow/python/eager/context.py
tensorflow/python/eager/function.py
tensorflow/python/eager/function_test.py
tensorflow/python/eager/graph_callable.py
tensorflow/python/estimator/export/export.py
tensorflow/python/estimator/export/export_output.py
tensorflow/python/estimator/export/export_output_test.py
tensorflow/python/framework/test_util.py
tensorflow/python/framework/test_util_test.py
tensorflow/python/kernel_tests/BUILD
tensorflow/python/kernel_tests/qr_op_test.py
tensorflow/python/kernel_tests/resource_variable_ops_test.py
tensorflow/python/kernel_tests/rnn_test.py
tensorflow/python/ops/linalg_grad.py
tensorflow/python/ops/resource_variable_ops.py
tensorflow/python/ops/rnn.py
tensorflow/python/pywrap_tfe.i
tensorflow/python/saved_model/signature_def_utils_impl.py
tensorflow/python/training/input.py
tensorflow/python/training/saver.py
tensorflow/python/training/saver_test.py
tensorflow/tools/ci_build/ci_sanity.sh
tensorflow/workspace.bzl