Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
曾经的那一瞬间
Tensorflow
提交
95abdf8d
T
Tensorflow
项目概览
曾经的那一瞬间
/
Tensorflow
10 个月 前同步成功
通知
10
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
Tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
95abdf8d
编写于
8月 29, 2023
作者:
K
Kevin Chen
提交者:
TensorFlower Gardener
8月 29, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rollback of: Update TopkRewriter to handle tensors with rank > 2
Broke argsort on gpu. PiperOrigin-RevId: 561136138
上级
26612c86
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
74 additions
and
140 deletions
+74
-140
tensorflow/compiler/tests/sort_ops_test.py
tensorflow/compiler/tests/sort_ops_test.py
+49
-33
tensorflow/compiler/xla/service/topk_rewriter.cc
tensorflow/compiler/xla/service/topk_rewriter.cc
+25
-71
tensorflow/compiler/xla/service/topk_rewriter_test.cc
tensorflow/compiler/xla/service/topk_rewriter_test.cc
+0
-36
未找到文件。
tensorflow/compiler/tests/sort_ops_test.py
浏览文件 @
95abdf8d
...
...
@@ -237,23 +237,12 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
self
.
_assertOpOutputMatchesExpected
(
wrap_sort
,
inputs
,
expected
=
inputs
)
@
parameterized
.
product
(
dtype
=
[
dtypes
.
bfloat16
.
as_numpy_dtype
,
np
.
float16
,
np
.
float32
,
np
.
float64
,
np
.
int32
,
np
.
uint32
,
np
.
int64
,
np
.
uint64
,
np
.
uint8
,
np
.
int8
,
],
rank
=
[
1
,
2
,
3
],
)
def
testTopK
(
self
,
dtype
,
rank
):
if
dtype
in
self
.
numeric_types
:
def
testTopK
(
self
):
supported_types
=
set
([
dtypes
.
bfloat16
.
as_numpy_dtype
,
np
.
float16
,
np
.
float32
,
np
.
float64
,
np
.
int32
,
np
.
uint32
,
np
.
int64
,
np
.
uint64
,
np
.
uint8
,
np
.
int8
,
])
for
dtype
in
supported_types
.
intersection
(
self
.
numeric_types
):
# Use small input size for bfloat16. Otherwise, we'll get duplicate values
# after conversion to bfloat16, so the possible resulting index array is
# no longer unique.
...
...
@@ -266,26 +255,53 @@ class XlaSortOpTest(xla_test.XLATestCase, parameterized.TestCase):
else
:
array_size
=
200
*
1000
k_options
=
[
0
,
1
,
2
,
10
,
20
,
100
,
1000
,
200
*
1000
]
for
x
in
[
np
.
arange
(
array_size
)]:
np
.
random
.
shuffle
(
x
)
for
k
in
k_options
:
indices
=
x
.
argsort
()[::
-
1
][:
k
]
# Tile array to tensor of specified rank, then shuffle along the last dim
x
=
np
.
arange
(
array_size
)
x
=
np
.
tile
(
x
,
(
2
,)
*
(
rank
-
1
)
+
(
1
,))
np
.
apply_along_axis
(
np
.
random
.
shuffle
,
-
1
,
x
)
def
topk
(
v
,
k
=
k
):
return
nn_ops
.
top_k
(
v
,
k
=
k
,
sorted
=
True
)
sorted_indices
=
x
.
argsort
(
axis
=-
1
)[...,
::
-
1
]
sorted_values
=
np
.
sort
(
x
,
axis
=-
1
)[...,
::
-
1
]
for
k
in
k_options
:
indices
=
sorted_indices
[...,
:
k
]
expected
=
sorted_values
[...,
:
k
]
self
.
_assertOpOutputMatchesExpected
(
topk
,
[
x
.
astype
(
dtype
)],
expected
=
[
x
[
indices
].
astype
(
dtype
),
indices
])
def
topk
(
v
,
k
=
k
):
return
nn_ops
.
top_k
(
v
,
k
=
k
,
sorted
=
True
)
@
parameterized
.
named_parameters
(
(
"HalfPrecision"
,
dtypes
.
bfloat16
.
as_numpy_dtype
),
(
"HalfFloatPrecision"
,
np
.
float16
),
(
"SinglePrecision"
,
np
.
float32
),
(
"DoublePrecision"
,
np
.
float64
),
(
"Int32"
,
np
.
int32
),
(
"UnsignedInt32"
,
np
.
uint32
),
(
"Int64"
,
np
.
int64
),
(
"UnsignedInt64"
,
np
.
uint64
),
)
def
testTopK2D
(
self
,
dtype
):
if
dtype
in
self
.
numeric_types
:
# Use small input size for bfloat16. Otherwise, we'll get duplicate values
# after conversion to bfloat16, so the possible resulting index array is
# no longer unique.
if
dtype
in
(
dtypes
.
bfloat16
.
as_numpy_dtype
,
np
.
float16
):
array_size
=
10
k_options
=
[
0
,
1
,
2
,
10
]
else
:
array_size
=
200
*
1000
k_options
=
[
0
,
1
,
2
,
10
,
20
,
100
,
1000
,
200
*
1000
]
batch
=
16
for
x
in
[
np
.
arange
(
batch
*
array_size
)]:
np
.
random
.
shuffle
(
x
)
x
=
np
.
reshape
(
x
,
[
batch
,
array_size
])
for
k
in
k_options
:
indices
=
x
.
argsort
(
axis
=
1
)[::,
-
1
:
-
k
-
1
:
-
1
]
expected
=
np
.
sort
(
x
,
axis
=
1
)[::,
-
1
:
-
k
-
1
:
-
1
]
self
.
_assertOpOutputMatchesExpected
(
topk
,
[
x
.
astype
(
dtype
)],
expected
=
[
expected
.
astype
(
dtype
),
indices
],
)
def
topk
(
v
,
k
=
k
):
return
nn_ops
.
top_k
(
v
,
k
=
k
,
sorted
=
True
)
self
.
_assertOpOutputMatchesExpected
(
topk
,
[
x
.
astype
(
dtype
)],
expected
=
[
expected
.
astype
(
dtype
),
indices
])
def
testTopKZeros
(
self
):
"""Tests that positive and negative zeros sort correctly."""
...
...
tensorflow/compiler/xla/service/topk_rewriter.cc
浏览文件 @
95abdf8d
...
...
@@ -196,6 +196,8 @@ std::optional<int64_t> TopkRewriter::SortIsInTopK(HloInstruction* inst) {
return
std
::
nullopt
;
}
const
int64_t
sort_dim
=
sort
->
sort_dimension
();
const
int64_t
batch_dim
=
sort_dim
==
1
?
0
:
1
;
const
bool
has_batch
=
data
->
shape
().
rank
()
==
2
;
bool
supported
=
true
;
std
::
optional
<
int64_t
>
k
;
...
...
@@ -220,15 +222,10 @@ std::optional<int64_t> TopkRewriter::SortIsInTopK(HloInstruction* inst) {
supported
=
false
;
break
;
}
for
(
int64_t
i
=
0
;
i
<
slice
->
slice_limits
().
size
();
++
i
)
{
if
(
i
!=
sort_dim
&&
slice
->
slice_limits
(
i
)
!=
slice
->
operand
(
0
)
->
shape
().
dimensions
(
i
))
{
// Slicing along a non-sort dimension isn't supported.
supported
=
false
;
break
;
}
}
if
(
!
supported
)
{
if
(
has_batch
&&
slice
->
slice_limits
(
batch_dim
)
!=
slice
->
operand
(
0
)
->
shape
().
dimensions
(
batch_dim
))
{
// Slicing along the batch dimension isn't supported.
supported
=
false
;
break
;
}
if
(
k
==
std
::
nullopt
)
{
...
...
@@ -260,57 +257,29 @@ StatusOr<bool> TopkRewriter::TransformToCustomCall(
HloSortInstruction
*
sort
=
DynCast
<
HloSortInstruction
>
(
inst
);
HloInstruction
*
data
=
sort
->
mutable_operand
(
0
);
const
PrimitiveType
element_type
=
data
->
shape
().
element_type
();
const
Shape
data_shape
=
data
->
shape
();
if
(
element_type
!=
F32
&&
element_type
!=
BF16
)
{
if
((
data
->
shape
().
rank
()
!=
1
&&
data
->
shape
().
rank
()
!=
2
)
||
(
element_type
!=
F32
&&
element_type
!=
BF16
))
{
continue
;
}
// Sort dimension must be the first or last dimension.
const
int64_t
sort_dim
=
sort
->
sort_dimension
();
if
(
sort_dim
!=
0
&&
sort_dim
!=
data_shape
.
rank
()
-
1
)
{
continue
;
}
const
int64_t
batch_dim
=
sort_dim
==
1
?
0
:
1
;
const
bool
has_batch
=
data
->
shape
().
rank
()
==
2
;
// Profitability check.
if
(
!
is_profitable_to_convert_
(
sort
,
*
k
))
{
continue
;
}
HloInstruction
*
input
=
data
;
const
bool
has_batch
=
data_shape
.
rank
()
>=
2
;
const
int64_t
input_size
=
data_shape
.
dimensions
(
sort_dim
);
int64_t
batch_size
=
1
;
Shape
topk_input_shape
;
if
(
has_batch
)
{
// The TopK custom call expects either a 1d tensor or a 2d tensor with
// the last dimension being the sort dimension. An input with rank > 2
// is reshaped into a 2d tensor by combining non-sort dimensions into a
// single batch dimension. The original non-sort dimensions are
// restored for the outputs with another reshape after the custom call.
batch_size
=
ShapeUtil
::
ElementsIn
(
data_shape
)
/
data_shape
.
dimensions
(
sort_dim
);
topk_input_shape
=
ShapeUtil
::
MakeShape
(
element_type
,
{
batch_size
,
input_size
});
if
(
data_shape
.
rank
()
>
2
)
{
// Reshape to 2d.
input
=
comp
->
AddInstruction
(
HloInstruction
::
CreateReshape
(
sort_dim
==
0
?
ShapeUtil
::
MakeShape
(
element_type
,
{
input_size
,
batch_size
})
:
ShapeUtil
::
MakeShape
(
element_type
,
{
batch_size
,
input_size
}),
input
));
}
if
(
sort_dim
==
0
)
{
// Transpose for the custom call when sorting the first dimension.
input
=
comp
->
AddInstruction
(
HloInstruction
::
CreateTranspose
(
topk_input_shape
,
input
,
{
1
,
0
}));
}
}
else
{
topk_input_shape
=
data_shape
;
const
int64_t
batch_size
=
has_batch
?
sort
->
operand
(
0
)
->
shape
().
dimensions
(
batch_dim
)
:
1
;
const
int64_t
input_size
=
sort
->
operand
(
0
)
->
shape
().
dimensions
(
sort_dim
);
HloInstruction
*
input
=
sort
->
mutable_operand
(
0
);
if
(
has_batch
&&
sort_dim
==
0
)
{
input
=
comp
->
AddInstruction
(
HloInstruction
::
CreateTranspose
(
ShapeUtil
::
MakeShape
(
element_type
,
{
batch_size
,
input_size
}),
input
,
{
1
,
0
}));
}
Shape
topk_shape
=
...
...
@@ -331,28 +300,13 @@ StatusOr<bool> TopkRewriter::TransformToCustomCall(
comp
->
AddInstruction
(
HloInstruction
::
CreateGetTupleElement
(
topk
->
shape
().
tuple_shapes
(
1
),
topk
,
1
));
if
(
has_batch
)
{
if
(
sort_dim
==
0
)
{
// Transpose back.
value_gte
=
comp
->
AddInstruction
(
HloInstruction
::
CreateTranspose
(
ShapeUtil
::
MakeShape
(
element_type
,
{
k
.
value
(),
batch_size
}),
value_gte
,
{
1
,
0
}));
index_gte
=
comp
->
AddInstruction
(
HloInstruction
::
CreateTranspose
(
ShapeUtil
::
MakeShape
(
S32
,
{
k
.
value
(),
batch_size
}),
index_gte
,
{
1
,
0
}));
}
if
(
data_shape
.
rank
()
>
2
)
{
// Reshape back.
Shape
value_shape
=
data_shape
;
value_shape
.
set_dimensions
(
sort_dim
,
k
.
value
());
value_gte
=
comp
->
AddInstruction
(
HloInstruction
::
CreateReshape
(
value_shape
,
value_gte
));
Shape
index_shape
=
ShapeUtil
::
MakeShape
(
S32
,
value_shape
.
dimensions
());
index_gte
=
comp
->
AddInstruction
(
HloInstruction
::
CreateReshape
(
index_shape
,
index_gte
));
}
if
(
has_batch
&&
sort_dim
==
0
)
{
value_gte
=
comp
->
AddInstruction
(
HloInstruction
::
CreateTranspose
(
ShapeUtil
::
MakeShape
(
element_type
,
{
k
.
value
(),
batch_size
}),
value_gte
,
{
1
,
0
}));
index_gte
=
comp
->
AddInstruction
(
HloInstruction
::
CreateTranspose
(
ShapeUtil
::
MakeShape
(
S32
,
{
k
.
value
(),
batch_size
}),
index_gte
,
{
1
,
0
}));
}
for
(
HloInstruction
*
user
:
sort
->
users
())
{
...
...
tensorflow/compiler/xla/service/topk_rewriter_test.cc
浏览文件 @
95abdf8d
...
...
@@ -326,42 +326,6 @@ ENTRY cluster {
EXPECT_THAT
(
cc
->
custom_call_target
(),
"TopK"
);
}
// Checks that a sort along dimension 2 of a rank-3 f32 operand, whose value
// and index outputs are each sliced to the first 5 entries of that dimension,
// is rewritten so that the root tuple consumes reshapes of the two outputs of
// a "TopK" custom call, which itself consumes a reshape of the parameter.
TEST_F(TopkRewriterTest, RewriteReshape) {
  const std::string hlo_text = R"(
HloModule module
)" + getComparator() + R"(
ENTRY cluster {
%arg_tuple.1 = f32[3,8,1234567] parameter(0)
%iota.4 = s32[3,8,1234567] iota(), iota_dimension=2
%sort.27 = (f32[3,8,1234567], s32[3,8,1234567]) sort(%arg_tuple.1, %iota.4),
dimensions={2}, is_stable=true, to_apply=%compare
%get-tuple-element.28 = f32[3, 8,1234567] get-tuple-element(%sort.27), index=0
%slice.29 = f32[3,8,5] slice(%get-tuple-element.28), slice={[0:3], [0:8], [0:5]}
%get-tuple-element.30 = s32[3,8,1234567] get-tuple-element(%sort.27), index=1
%slice.31 = s32[3,8,5] slice(%get-tuple-element.30), slice={[0:3], [0:8], [0:5]}
ROOT %tuple.32 = (f32[3,8,5], s32[3,8,5]) tuple(%slice.29, %slice.31)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // The profitability callback accepts every candidate so the rewrite is
  // always attempted.
  TopkRewriter topk_rewriter(
      [](const HloSortInstruction*, int64_t) { return true; });
  TF_ASSERT_OK_AND_ASSIGN(bool was_rewritten,
                          topk_rewriter.Run(hlo_module.get()));
  // Run DCE before inspecting the rewritten graph.
  TF_ASSERT_OK(HloDCE().Run(hlo_module.get()).status());
  EXPECT_TRUE(was_rewritten);

  // Both tuple operands must be reshape(get-tuple-element(custom-call(
  // reshape(parameter(0))))), taking tuple indices 0 and 1 respectively.
  EXPECT_THAT(
      hlo_module->entry_computation()->root_instruction(),
      GmockMatch(m::Tuple(
          m::Reshape(m::GetTupleElement(
              m::CustomCall(m::Reshape(m::Parameter(0))), 0)),
          m::Reshape(m::GetTupleElement(
              m::CustomCall(m::Reshape(m::Parameter(0))), 1)))));

  // Walk tuple -> reshape -> get-tuple-element -> custom call.
  const HloInstruction* topk_call = hlo_module->entry_computation()
                                        ->root_instruction()
                                        ->operand(0)
                                        ->operand(0)
                                        ->operand(0);
  EXPECT_THAT(topk_call->custom_call_target(), "TopK");
}
TEST_F
(
TopkRewriterTest
,
RewriteNoIota
)
{
const
std
::
string
hlo_string
=
R"(
HloModule module
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录