From 5b76abd4497cd4a121c8f64b8db107f8a1de6388 Mon Sep 17 00:00:00 2001 From: Yuanzhong Xu Date: Fri, 30 Oct 2020 22:10:01 -0700 Subject: [PATCH] [XLA:SPMD] Fix contracting dim loop einsum threshold PiperOrigin-RevId: 339989040 Change-Id: I46eaffb4e4b28c168b5f0182c4c0bb946d6eefe8 --- .../compiler/xla/service/spmd/dot_handler.cc | 12 +++++------ .../xla/service/spmd/spmd_partitioner_test.cc | 20 +++++++++---------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc index e958b0f1eb6..9f366ec97cd 100644 --- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc +++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc @@ -536,15 +536,13 @@ StatusOr PartitionBaseCase( } if (lhs_contracting_partitions == rhs_contracting_partitions && lhs_contracting_partitions == num_partitions && - output_sharding_dim > -1) { - if (output_lhs_non_contracting_partitions == num_partitions && - ShapeSizeInBytes(rhs.base_shape()) >= - options.threshold_for_windowed_einsum_mib * 1024 * 1024) { + output_sharding_dim > -1 && + ShapeSizeInBytes(output_base_shape) >= + options.threshold_for_windowed_einsum_mib * 1024 * 1024) { + if (output_lhs_non_contracting_partitions == num_partitions) { return emit_windowed_dot_general(0, 1, false, false, true); } - if (output_rhs_non_contracting_partitions == num_partitions && - ShapeSizeInBytes(lhs.base_shape()) >= - options.threshold_for_windowed_einsum_mib * 1024 * 1024) { + if (output_rhs_non_contracting_partitions == num_partitions) { return emit_windowed_dot_general(1, 0, false, false, true); } } diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc index 5247ef88abb..6144a12bb59 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -3681,12 +3681,12 @@ TEST_F(SpmdPartitioningTest, HloModule module ENTRY entry { - %lhs = f32[32,25,64,128] parameter(0) - %lhs.copy = f32[32,25,64,128] copy(%lhs), sharding={devices=[1,1,4,1]0,1,2,3} - %rhs = f32[32,39296,64,128] parameter(1) - %rhs.copy = f32[32,39296,64,128] copy(%rhs), + %lhs = f32[320,25,64,128] parameter(0) + %lhs.copy = f32[320,25,64,128] copy(%lhs), sharding={devices=[1,1,4,1]0,1,2,3} + %rhs = f32[320,39296,64,128] parameter(1) + %rhs.copy = f32[320,39296,64,128] copy(%rhs), sharding={devices=[1,1,4,1]0,1,2,3} - ROOT %dot = f32[32,25,39296] dot(%lhs.copy, %rhs.copy), + ROOT %dot = f32[320,25,39296] dot(%lhs.copy, %rhs.copy), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3}, sharding={devices=[1,4,1]0,1,2,3} @@ -3700,14 +3700,14 @@ ENTRY entry { auto lhs = AllOf( op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), op::Constant(), op::Reshape(), op::Constant())), - op::Shape("f32[32,25,16,128]")); + op::Shape("f32[320,25,16,128]")); auto rhs = AllOf( op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(), op::Constant(), op::Reshape(), op::Constant())), - op::Shape("f32[32,39296,16,128]")); + op::Shape("f32[320,39296,16,128]")); EXPECT_THAT(root, AllOf(op::GetTupleElement(op::While(op::Tuple( lhs, rhs, op::Broadcast(), op::Constant()))), - op::Shape("f32[32,7,39296]"))); + op::Shape("f32[320,7,39296]"))); auto while_loop = root->operand(0); // Check loop condition. @@ -3721,11 +3721,11 @@ ENTRY entry { AllOf(op::DynamicSlice( op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()), op::Constant(), op::Multiply(), op::Constant(), op::Constant()), - op::Shape("f32[32,7,16,128]")); + op::Shape("f32[320,7,16,128]")); auto partial_output = AllOf(op::Add(op::GetTupleElement(op::Parameter(0)), op::Dot(ds, op::GetTupleElement(op::Parameter(0)))), - op::Shape("f32[32,7,39296]")); + op::Shape("f32[320,7,39296]")); auto window = op::Conditional(op::Compare(next_i, op::Constant()), partial_output, partial_output); EXPECT_THAT(while_loop->while_body()->root_instruction(), -- GitLab