提交 0b2d2319 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

[XLA] Add test for broadcast of a scalar provided as a parameter.

Change: 150490700
上级 815c7a76
......@@ -48,6 +48,19 @@ XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) {
ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) {
ComputationBuilder b(client_, TestName());
ComputationDataHandle src;
std::unique_ptr<GlobalData> param_data =
CreateR0Parameter<float>(2.25f, /*parameter_number=*/0, /*name=*/"src",
/*builder=*/&b, /*data_handle=*/&src);
b.Broadcast(src, {2, 3});
Array2D<float> expected(2, 3, 2.25);
ComputeAndCompareR2<float>(&b, expected, {param_data.get()},
ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) {
ComputationBuilder b(client_, TestName());
b.Broadcast(b.ConstantR0<float>(2.25), {2, 0});
......
......@@ -216,6 +216,16 @@ class ClientLibraryTestBase : public ::testing::Test {
const int rows, const int cols, const int rows_padded,
const int cols_padded);
// Create a parameter instruction that wraps a given value and then stores
// into "data_handle" the global handle for that parameter.
//
// "parameter_number" is the parameter number.
// "name" is the name of the parameter instruction.
template <typename NativeT>
std::unique_ptr<GlobalData> CreateR0Parameter(
NativeT value, int64 parameter_number, const string& name,
ComputationBuilder* builder, ComputationDataHandle* data_handle);
// Create a parameter instruction that wraps the given values and then stores
// into "data_handle" the global handle for that parameter.
//
......@@ -370,6 +380,17 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
arguments, error);
}
template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
NativeT value, int64 parameter_number, const string& name,
ComputationBuilder* builder, ComputationDataHandle* data_handle) {
std::unique_ptr<Literal> literal = LiteralUtil::CreateR0(value);
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
*data_handle = builder->Parameter(parameter_number, literal->shape(), name);
return data;
}
template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册