Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
7c93af2b
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
7c93af2b
编写于
7月 16, 2019
作者:
T
TensorFlower Gardener
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #30170 from DavidNorman:allow-disable-dot-to-multiply
PiperOrigin-RevId: 258472358
上级
579a5c94
5da31a1a
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
37 addition
and
1 deletion
+37
-1
tensorflow/compiler/xla/service/algebraic_simplifier.cc
tensorflow/compiler/xla/service/algebraic_simplifier.cc
+2
-1
tensorflow/compiler/xla/service/algebraic_simplifier.h
tensorflow/compiler/xla/service/algebraic_simplifier.h
+10
-0
tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+25
-0
未找到文件。
tensorflow/compiler/xla/service/algebraic_simplifier.cc
浏览文件 @
7c93af2b
...
...
@@ -1705,7 +1705,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
// If there are no contracting dimensions, a dot can be rewritten as
// mul(broadcast(transpose(x)),broadcast(transpose(y)))
if
(
dot
->
dot_dimension_numbers
().
lhs_contracting_dimensions_size
()
==
0
)
{
if
(
options_
.
enable_dot_to_multiply_rewrite
()
&&
dot
->
dot_dimension_numbers
().
lhs_contracting_dimensions_size
()
==
0
)
{
TF_ASSIGN_OR_RETURN
(
HloInstruction
*
new_lhs
,
NormalizeDotOperandToBatchMajorAndContractingMinor
(
...
...
tensorflow/compiler/xla/service/algebraic_simplifier.h
浏览文件 @
7c93af2b
...
...
@@ -63,6 +63,15 @@ class AlgebraicSimplifierOptions {
return
enable_dot_strength_reduction_
;
}
// Enable dot->multiple rewrite for dot as an outer-product
void
set_enable_dot_to_multiply_rewrite
(
bool
enable_dot_to_multiply_rewrite
)
{
enable_dot_to_multiply_rewrite_
=
enable_dot_to_multiply_rewrite
;
}
bool
enable_dot_to_multiply_rewrite
()
const
{
return
enable_dot_to_multiply_rewrite_
;
}
// Enable convolution simplification on platforms where it is profitable.
void
set_enable_conv_simplification
(
bool
enable_conv_simplification
)
{
enable_conv_simplification_
=
enable_conv_simplification
;
...
...
@@ -87,6 +96,7 @@ class AlgebraicSimplifierOptions {
ReshapeIsBitcastCallback
reshape_is_bitcast_callback_
;
bool
is_layout_sensitive_
{
false
};
bool
enable_dot_strength_reduction_
{
true
};
bool
enable_dot_to_multiply_rewrite_
{
true
};
bool
enable_conv_simplification_
{
true
};
bool
enable_window_reduce_to_reduce_replacement_
{
true
};
};
...
...
tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
浏览文件 @
7c93af2b
...
...
@@ -5457,6 +5457,31 @@ TEST_F(AlgebraicSimplifierTest, CompareSame) {
GmockMatch
(
m
::
Broadcast
(
m
::
ConstantScalar
(
true
))));
}
TEST_F
(
AlgebraicSimplifierTest
,
CanDisableDotToMultiplyRewrite
)
{
// Some backends may have better performance by treating an outer product as a
// Dot, rather than a broadcast Multiply
const
char
*
kModuleStr
=
R"(
HloModule m
test {
param1 = f32[64] parameter(0)
param2 = f32[64] parameter(1)
ROOT compare = f32[64, 64] dot(param1, param2),
lhs_contracting_dims={}, rhs_contracting_dims={}
})"
;
// Verify that the default is to re-write
TF_ASSERT_OK_AND_ASSIGN
(
auto
m1
,
ParseAndReturnVerifiedModule
(
kModuleStr
));
ASSERT_TRUE
(
AlgebraicSimplifier
(
default_options_
).
Run
(
m1
.
get
()).
ValueOrDie
());
EXPECT_THAT
(
m1
->
entry_computation
()
->
root_instruction
(),
GmockMatch
(
m
::
Multiply
(
m
::
Op
(),
m
::
Op
())));
// Verify that we can disable the re-write
AlgebraicSimplifierOptions
opts
=
default_options_
;
opts
.
set_enable_dot_to_multiply_rewrite
(
false
);
TF_ASSERT_OK_AND_ASSIGN
(
auto
m2
,
ParseAndReturnVerifiedModule
(
kModuleStr
));
ASSERT_FALSE
(
AlgebraicSimplifier
(
opts
).
Run
(
m2
.
get
()).
ValueOrDie
());
}
TEST_F
(
AlgebraicSimplifierTest
,
RemainderOfIota
)
{
const
char
*
kModuleStr
=
R"(
HloModule m
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录