Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
0cb2a73b
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
0cb2a73b
编写于
8月 12, 2019
作者:
Y
Yunxing Dai
提交者:
TensorFlower Gardener
8月 12, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XLA] [DynamicPadder] Support sort op.
PiperOrigin-RevId: 262951106
上级
40897309
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
55 addition
and
1 deletion
+55
-1
tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/BUILD
+2
-1
tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
...rflow/compiler/xla/service/dynamic_dimension_inference.cc
+20
-0
tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc
.../compiler/xla/service/dynamic_dimension_inference_test.cc
+32
-0
tensorflow/compiler/xla/service/dynamic_padder.cc
tensorflow/compiler/xla/service/dynamic_padder.cc
+1
-0
未找到文件。
tensorflow/compiler/xla/service/BUILD
浏览文件 @
0cb2a73b
...
...
@@ -2173,13 +2173,14 @@ cc_library(
hdrs
=
[
"dynamic_dimension_inference.h"
],
deps
=
[
":hlo"
,
":hlo_casting_utils"
,
":while_util"
,
"//tensorflow/compiler/xla:literal_util"
,
"//tensorflow/compiler/xla:status"
,
"//tensorflow/compiler/xla:statusor"
,
"//tensorflow/compiler/xla:types"
,
"//tensorflow/compiler/xla:window_util"
,
"//tensorflow/core
:lib
"
,
"//tensorflow/core
/platform:macros
"
,
"@com_google_absl//absl/container:flat_hash_map"
,
"@com_google_absl//absl/types:span"
,
],
...
...
tensorflow/compiler/xla/service/dynamic_dimension_inference.cc
浏览文件 @
0cb2a73b
...
...
@@ -17,8 +17,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/window_util.h"
...
...
@@ -53,6 +55,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault {
Status
HandleReshape
(
HloInstruction
*
hlo
)
override
;
Status
HandleSort
(
HloInstruction
*
hlo
)
override
;
Status
HandlePad
(
HloInstruction
*
hlo
)
override
;
Status
HandleBroadcast
(
HloInstruction
*
hlo
)
override
;
...
...
@@ -161,6 +165,22 @@ Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) {
});
}
Status
DynamicDimensionInferenceVisitor
::
HandleSort
(
HloInstruction
*
hlo
)
{
return
ForEachOperandDynamicDimension
(
hlo
,
[
&
](
HloInstruction
*
operand
,
ShapeIndex
index
,
int64
dynamic_dimension
,
int64
operand_index
,
HloInstruction
*
dynamic_size
,
DimensionConstraint
constraint
)
{
int64
sort_dimension
=
Cast
<
HloSortInstruction
>
(
hlo
)
->
sort_dimension
();
if
(
sort_dimension
==
dynamic_dimension
)
{
return
Unimplemented
(
"Dynamic dimension on sorting dimension is not supported"
);
}
parent_
->
SetDynamicSize
(
hlo
,
{},
dynamic_dimension
,
dynamic_size
,
constraint
);
return
Status
::
OK
();
});
}
Status
DynamicDimensionInferenceVisitor
::
HandlePad
(
HloInstruction
*
hlo
)
{
return
ForEachOperandDynamicDimension
(
hlo
,
[
&
](
HloInstruction
*
operand
,
ShapeIndex
index
,
int64
dimension
,
...
...
tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc
浏览文件 @
0cb2a73b
...
...
@@ -912,6 +912,38 @@ TEST_F(DynamicDimensionInferenceTest, DynamicSliceTest) {
EXPECT_EQ
(
inference_
->
GetDynamicSize
(
slice
,
{},
0
),
size_param
);
}
TEST_F
(
DynamicDimensionInferenceTest
,
SortTest
)
{
auto
builder
=
HloComputation
::
Builder
(
TestName
());
auto
data_param
=
builder
.
AddInstruction
(
HloInstruction
::
CreateParameter
(
0
,
ShapeUtil
::
MakeShape
(
F32
,
{
5
,
7
}),
"data_param"
));
auto
size_param
=
builder
.
AddInstruction
(
HloInstruction
::
CreateParameter
(
1
,
scalar_shape_
,
"size_param"
));
auto
compare_builder
=
HloComputation
::
Builder
(
"condition"
);
compare_builder
.
AddInstruction
(
HloInstruction
::
CreateParameter
(
0
,
ShapeUtil
::
MakeShape
(
F32
,
{}),
"param1"
));
compare_builder
.
AddInstruction
(
HloInstruction
::
CreateParameter
(
1
,
ShapeUtil
::
MakeShape
(
F32
,
{}),
"param2"
));
compare_builder
.
AddInstruction
(
HloInstruction
::
CreateConstant
(
LiteralUtil
::
CreateR0
<
bool
>
(
false
)));
HloComputation
*
compare
=
module_
->
AddEmbeddedComputation
(
compare_builder
.
Build
());
auto
*
sort
=
builder
.
AddInstruction
(
HloInstruction
::
CreateSort
(
ShapeUtil
::
MakeShape
(
F32
,
{
5
,
7
}),
1
,
{
data_param
},
compare
,
/*is_stable=*/
false
));
module_
->
AddEntryComputation
(
builder
.
Build
());
// Set up dynamic parameter binding.
TF_CHECK_OK
(
module_
->
dynamic_parameter_binding
().
Bind
(
DynamicParameterBinding
::
DynamicParameter
{
1
,
{}},
DynamicParameterBinding
::
DynamicDimension
{
0
,
{},
0
}));
TF_ASSERT_OK
(
RunInference
());
EXPECT_EQ
(
inference_
->
GetDynamicSize
(
sort
,
{},
0
),
size_param
);
}
TEST_F
(
DynamicDimensionInferenceTest
,
DynamicSliceSingleElementTest
)
{
// Slicing out a single element from a dynamic dimension terminates the
// dynamic dimension.
...
...
tensorflow/compiler/xla/service/dynamic_padder.cc
浏览文件 @
0cb2a73b
...
...
@@ -90,6 +90,7 @@ StatusOr<HloInstruction*> ChooseIdentityValue(HloInstruction* inst,
case
HloOpcode
::
kAllReduce
:
case
HloOpcode
::
kBroadcast
:
case
HloOpcode
::
kTranspose
:
case
HloOpcode
::
kSort
:
case
HloOpcode
::
kSlice
:
return
nullptr
;
default:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录