diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index dff2f8f022a06b6a892da801d41e2f6c7d6edc8b..31db3a0dfe945cfb47c9f7560d8278f8aefa5e58 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -174,6 +174,27 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } +XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { + ComputationBuilder b(client_, TestName()); + auto r1_0 = b.ConstantR1({1000, 2000}); + auto r1_1 = b.ConstantR1({100, 200}); + auto r1_2 = b.ConstantR1({10, 20}); + auto r0 = b.ConstantR0(3); + auto r3 = b.Broadcast(r0, {2, 2, 2}); + for (int i = 0; i < 3; ++i) { + r3 = b.Add(r1_0, r3, {0}); + r3 = b.Add(r3, r1_1, {1}); + r3 = b.Add(r1_2, r3, {2}); + } + r3 = b.Mul(r3, b.ConstantR0(-1)); + + auto expected = LiteralUtil::CreateR3( + {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, + {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}}); + + ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); +} + XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { // Binary dimension broadcasting of the smaller lhs ([2, 2] up to [2, 2, 2]) // results in a shape incompatible with the lhs [2, 3, 1].