diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
index e958b0f1eb65ac8191182ad92cef4efa526890ef..9f366ec97cd43048bbb63e8ebac4f565f33ad97c 100644
--- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
@@ -536,15 +536,13 @@ StatusOr<HloInstruction*> PartitionBaseCase(
   }
   if (lhs_contracting_partitions == rhs_contracting_partitions &&
       lhs_contracting_partitions == num_partitions &&
-      output_sharding_dim > -1) {
-    if (output_lhs_non_contracting_partitions == num_partitions &&
-        ShapeSizeInBytes(rhs.base_shape()) >=
-            options.threshold_for_windowed_einsum_mib * 1024 * 1024) {
+      output_sharding_dim > -1 &&
+      ShapeSizeInBytes(output_base_shape) >=
+          options.threshold_for_windowed_einsum_mib * 1024 * 1024) {
+    if (output_lhs_non_contracting_partitions == num_partitions) {
       return emit_windowed_dot_general(0, 1, false, false, true);
     }
-    if (output_rhs_non_contracting_partitions == num_partitions &&
-        ShapeSizeInBytes(lhs.base_shape()) >=
-            options.threshold_for_windowed_einsum_mib * 1024 * 1024) {
+    if (output_rhs_non_contracting_partitions == num_partitions) {
       return emit_windowed_dot_general(1, 0, false, false, true);
     }
   }
diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
index 5247ef88abbdb559b95ecb11c7bdebd4518f234e..6144a12bb59f194310d8258e792f309a5e192c2d 100644
--- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
+++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
@@ -3681,12 +3681,12 @@ TEST_F(SpmdPartitioningTest,
 HloModule module
 
 ENTRY entry {
-  %lhs = f32[32,25,64,128] parameter(0)
-  %lhs.copy = f32[32,25,64,128] copy(%lhs), sharding={devices=[1,1,4,1]0,1,2,3}
-  %rhs = f32[32,39296,64,128] parameter(1)
-  %rhs.copy = f32[32,39296,64,128] copy(%rhs),
+  %lhs = f32[320,25,64,128] parameter(0)
+  %lhs.copy = f32[320,25,64,128] copy(%lhs), sharding={devices=[1,1,4,1]0,1,2,3}
+  %rhs = f32[320,39296,64,128] parameter(1)
+  %rhs.copy = f32[320,39296,64,128] copy(%rhs),
     sharding={devices=[1,1,4,1]0,1,2,3}
-  ROOT %dot = f32[32,25,39296] dot(%lhs.copy, %rhs.copy),
+  ROOT %dot = f32[320,25,39296] dot(%lhs.copy, %rhs.copy),
     lhs_batch_dims={0}, rhs_batch_dims={0},
     lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
     sharding={devices=[1,4,1]0,1,2,3}
@@ -3700,14 +3700,14 @@ ENTRY entry {
   auto lhs = AllOf(
       op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
                                 op::Constant(), op::Reshape(), op::Constant())),
-      op::Shape("f32[32,25,16,128]"));
+      op::Shape("f32[320,25,16,128]"));
   auto rhs = AllOf(
       op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
                                 op::Constant(), op::Reshape(), op::Constant())),
-      op::Shape("f32[32,39296,16,128]"));
+      op::Shape("f32[320,39296,16,128]"));
   EXPECT_THAT(root, AllOf(op::GetTupleElement(op::While(op::Tuple(
                               lhs, rhs, op::Broadcast(), op::Constant()))),
-                          op::Shape("f32[32,7,39296]")));
+                          op::Shape("f32[320,7,39296]")));
 
   auto while_loop = root->operand(0);
   // Check loop condition.
@@ -3721,11 +3721,11 @@ ENTRY entry {
       AllOf(op::DynamicSlice(
                 op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
                 op::Constant(), op::Multiply(), op::Constant(), op::Constant()),
-            op::Shape("f32[32,7,16,128]"));
+            op::Shape("f32[320,7,16,128]"));
   auto partial_output =
       AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
                     op::Dot(ds, op::GetTupleElement(op::Parameter(0)))),
-            op::Shape("f32[32,7,39296]"));
+            op::Shape("f32[320,7,39296]"));
   auto window = op::Conditional(op::Compare(next_i, op::Constant()),
                                 partial_output, partial_output);
   EXPECT_THAT(while_loop->while_body()->root_instruction(),
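
Note on the dot_handler.cc change above: the size threshold for taking the windowed dot-general path is hoisted out of the two per-operand branches and now compares the output base shape, rather than the opposite-side operand, against threshold_for_windowed_einsum_mib. Below is a minimal standalone sketch of the new gating condition only, under the assumption of simplified inputs; the struct and function names are illustrative stand-ins, not the actual XLA API (the real logic lives inside PartitionBaseCase and uses xla::Shape, ShapeSizeInBytes, and SpmdPartitionerOptions).

// Sketch only: hypothetical names, simplified types.
#include <cstdint>

struct EinsumOptions {
  int64_t threshold_for_windowed_einsum_mib;  // threshold in MiB
};

// Windowed einsum is considered only when both operands are partitioned on
// their contracting dimensions across all partitions, the output has a
// shardable dimension, and the *output* base shape exceeds the threshold.
bool WindowedEinsumCandidate(int64_t lhs_contracting_partitions,
                             int64_t rhs_contracting_partitions,
                             int64_t num_partitions,
                             int64_t output_sharding_dim,
                             int64_t output_base_shape_bytes,
                             const EinsumOptions& options) {
  return lhs_contracting_partitions == rhs_contracting_partitions &&
         lhs_contracting_partitions == num_partitions &&
         output_sharding_dim > -1 &&
         output_base_shape_bytes >=
             options.threshold_for_windowed_einsum_mib * 1024 * 1024;
}

The test change is consistent with this: the batch dimension grows from 32 to 320 so that the output f32[320,25,39296] crosses the size threshold and the partitioner still emits the windowed while-loop the test expects.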