diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 31db3a0dfe945cfb47c9f7560d8278f8aefa5e58..63744afb4ea72006262aad74e9b8d75a09b107e6 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -48,6 +48,19 @@ XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) { ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } +XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) { + ComputationBuilder b(client_, TestName()); + ComputationDataHandle src; + std::unique_ptr param_data = + CreateR0Parameter(2.25f, /*parameter_number=*/0, /*name=*/"src", + /*builder=*/&b, /*data_handle=*/&src); + + b.Broadcast(src, {2, 3}); + Array2D expected(2, 3, 2.25); + ComputeAndCompareR2(&b, expected, {param_data.get()}, + ErrorSpec(0.0001)); +} + XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) { ComputationBuilder b(client_, TestName()); b.Broadcast(b.ConstantR0(2.25), {2, 0}); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 026f487c2df2f91b369921a07acbc83c581ef8f0..92c08d03acdd5653484c09d1d28e401acc328310 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -216,6 +216,16 @@ class ClientLibraryTestBase : public ::testing::Test { const int rows, const int cols, const int rows_padded, const int cols_padded); + // Create a parameter instruction that wraps a given value and then stores + // into "data_handle" the global handle for that parameter. + // + // "parameter_number" is the parameter number. + // "name" is the name of the parameter instruction. + template + std::unique_ptr CreateR0Parameter( + NativeT value, int64 parameter_number, const string& name, + ComputationBuilder* builder, ComputationDataHandle* data_handle); + // Create a parameter instruction that wraps the given values and then stores // into "data_handle" the global handle for that parameter. // @@ -370,6 +380,17 @@ void ClientLibraryTestBase::ComputeAndCompareR4( arguments, error); } +template +std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( + NativeT value, int64 parameter_number, const string& name, + ComputationBuilder* builder, ComputationDataHandle* data_handle) { + std::unique_ptr literal = LiteralUtil::CreateR0(value); + std::unique_ptr data = + client_->TransferToServer(*literal).ConsumeValueOrDie(); + *data_handle = builder->Parameter(parameter_number, literal->shape(), name); + return data; +} + template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number,