提交 7c93af2b 编写于 作者: T TensorFlower Gardener

Merge pull request #30170 from DavidNorman:allow-disable-dot-to-multiply

PiperOrigin-RevId: 258472358
......@@ -1705,7 +1705,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
// If there are no contracting dimensions, a dot can be rewritten as
// mul(broadcast(transpose(x)),broadcast(transpose(y)))
if (dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) {
if (options_.enable_dot_to_multiply_rewrite() &&
dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == 0) {
TF_ASSIGN_OR_RETURN(
HloInstruction * new_lhs,
NormalizeDotOperandToBatchMajorAndContractingMinor(
......
......@@ -63,6 +63,15 @@ class AlgebraicSimplifierOptions {
return enable_dot_strength_reduction_;
}
// Enable dot->multiple rewrite for dot as an outer-product
void set_enable_dot_to_multiply_rewrite(bool enable_dot_to_multiply_rewrite) {
enable_dot_to_multiply_rewrite_ = enable_dot_to_multiply_rewrite;
}
bool enable_dot_to_multiply_rewrite() const {
return enable_dot_to_multiply_rewrite_;
}
// Enable convolution simplification on platforms where it is profitable.
void set_enable_conv_simplification(bool enable_conv_simplification) {
enable_conv_simplification_ = enable_conv_simplification;
......@@ -87,6 +96,7 @@ class AlgebraicSimplifierOptions {
ReshapeIsBitcastCallback reshape_is_bitcast_callback_;
bool is_layout_sensitive_{false};
bool enable_dot_strength_reduction_{true};
bool enable_dot_to_multiply_rewrite_{true};
bool enable_conv_simplification_{true};
bool enable_window_reduce_to_reduce_replacement_{true};
};
......
......@@ -5457,6 +5457,31 @@ TEST_F(AlgebraicSimplifierTest, CompareSame) {
GmockMatch(m::Broadcast(m::ConstantScalar(true))));
}
TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) {
// Some backends may have better performance by treating an outer product as a
// Dot, rather than a broadcast Multiply
const char* kModuleStr = R"(
HloModule m
test {
param1 = f32[64] parameter(0)
param2 = f32[64] parameter(1)
ROOT compare = f32[64, 64] dot(param1, param2),
lhs_contracting_dims={}, rhs_contracting_dims={}
})";
// Verify that the default is to re-write
TF_ASSERT_OK_AND_ASSIGN(auto m1, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m1.get()).ValueOrDie());
EXPECT_THAT(m1->entry_computation()->root_instruction(),
GmockMatch(m::Multiply(m::Op(), m::Op())));
// Verify that we can disable the re-write
AlgebraicSimplifierOptions opts = default_options_;
opts.set_enable_dot_to_multiply_rewrite(false);
TF_ASSERT_OK_AND_ASSIGN(auto m2, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_FALSE(AlgebraicSimplifier(opts).Run(m2.get()).ValueOrDie());
}
TEST_F(AlgebraicSimplifierTest, RemainderOfIota) {
const char* kModuleStr = R"(
HloModule m
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册