提交 4fe05f35 编写于 作者: P Peter Hawkins 提交者: TensorFlower Gardener

[XLA:CPU] Add support for CustomCall targets that return tuples.

Populate the tuple index table of the return value; the callee cannot do this since it does not know the buffer assignments.

Explicitly enable custom_call_test only for cpu in the BUILD file, rather than disabling it on non-CPU backends. These tests would not work on any non-CPU backend.

PiperOrigin-RevId: 225048065
上级 31666006
......@@ -2271,6 +2271,22 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
/*isVarArg=*/false)));
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
// Write the tuple index table if the custom call returns a tuple. The
// callee cannot do this itself because only the compiler knows the buffer
// assignment (where each tuple element lives), so the IR emitter fills in
// the table of element pointers after the target address is computed.
if (ShapeUtil::IsTuple(custom_call->shape())) {
std::vector<llvm::Value*> base_ptrs;
for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape());
++i) {
const Shape& elem_shape =
ShapeUtil::GetTupleElementShape(custom_call->shape(), i);
// Only one level of nesting is supported; a nested tuple would need a
// recursive table write, which is not implemented here.
TF_RET_CHECK(!ShapeUtil::IsTuple(elem_shape))
<< "Nested tuples not implemented";
// Look up the buffer slice assigned to tuple element {i} and emit a
// pointer to it; these pointers become the tuple table entries.
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
assignment_.GetUniqueSlice(custom_call, {i}));
llvm::Value* addr = EmitBufferPointer(slice, elem_shape);
base_ptrs.push_back(addr);
}
llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_, module_);
}
auto* output_address_arg =
PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
......
# Description:
# Base testing infrastructure for XLA.
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library")
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"tf_cuda_tests_tags",
)
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
licenses(["notice"]) # Apache 2.0
package(
......@@ -23,17 +30,6 @@ filegroup(
]),
)
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test_library")
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites")
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"tf_cuda_tests_tags",
)
# Generate test_suites for all backends, named "${backend}_tests".
generate_backend_suites()
......@@ -1348,6 +1344,7 @@ xla_test(
xla_test(
name = "custom_call_test",
srcs = ["custom_call_test.cc"],
backends = ["cpu"],
deps = [
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
......
......@@ -54,11 +54,20 @@ void Add1ToValues(float* out, float** in) {
out[2] = array[2] + 1;
out[3] = array[3] + 1;
}
// Custom-call target used by the TupleOutput test: takes two f32 scalars
// and produces a two-element tuple containing them in swapped order.
// `out` is the array of output-element buffers, `in` the input buffers.
void F32TupleSwap(float** out, float** in) {
  // Tell MSan the incoming buffers are initialized (they were written by
  // XLA-generated code, which the sanitizer cannot see into).
  TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[0], sizeof(float));
  TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[1], sizeof(float));
  // Read both operands before writing so the swap is explicit.
  // (Input and output buffers of an XLA custom call do not alias.)
  const float lhs = *in[0];
  const float rhs = *in[1];
  *out[0] = rhs;
  *out[1] = lhs;
}
} // namespace
REGISTER_CUSTOM_CALL_TARGET(R0F32Add2);
REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum);
REGISTER_CUSTOM_CALL_TARGET(Add1ToValues);
REGISTER_CUSTOM_CALL_TARGET(F32TupleSwap);
namespace xla {
namespace {
......@@ -69,7 +78,7 @@ class CustomCallTest : public HloTestBase {
Shape r2f32_ = ShapeUtil::MakeShape(F32, {2, 2});
};
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) {
XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2) {
auto module = CreateNewUnverifiedModule();
auto builder = HloComputation::Builder(TestName());
......@@ -84,7 +93,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) {
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
}
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) {
auto module = CreateNewUnverifiedModule();
auto builder = HloComputation::Builder(TestName());
......@@ -105,7 +114,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
}
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) {
XLA_TEST_F(CustomCallTest, UsedInOtherComputations) {
auto module = CreateNewUnverifiedModule();
auto b = HloComputation::Builder(TestName());
......@@ -129,7 +138,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) {
Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result);
}
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) {
XLA_TEST_F(CustomCallTest, InputAndOutputLayoutDiffer) {
auto module = CreateNewUnverifiedModule();
auto b = HloComputation::Builder(TestName());
......@@ -151,7 +160,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) {
LiteralTestUtil::ExpectR2Equal<float>({{2.f, 4.f}, {3.f, 5.f}}, result);
}
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) {
XLA_TEST_F(CustomCallTest, LayoutConstrained) {
// The argument and result of the computation are set to different layouts,
// but the custom call is layout constrained to a fixed operand and result
// layout, so the correct result should be produced.
......@@ -176,6 +185,26 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) {
LiteralTestUtil::ExpectR2Equal<float>({{2.f, 3.f}, {4.f, 5.f}}, result);
}
// End-to-end check that a CPU custom call may return a tuple: the
// F32TupleSwap target yields its two scalar operands in reverse order,
// exercising the tuple-index-table emission in the IR emitter.
XLA_TEST_F(CustomCallTest, TupleOutput) {
  const char* hlo_text = R"(
HloModule m
test {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT %custom-call = (f32[], f32[]) custom-call(f32[] %p0, f32[] %p1), custom_call_target="F32TupleSwap", operand_layout_constraints={f32[], f32[]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  Literal lhs = LiteralUtil::CreateR0<float>(7.f);
  Literal rhs = LiteralUtil::CreateR0<float>(42.f);
  // The target swaps its operands, so the expected tuple is (rhs, lhs).
  Literal expected = LiteralUtil::MakeTuple({&rhs, &lhs});

  Literal actual = ExecuteAndTransfer(std::move(module), {&lhs, &rhs});
  EXPECT_EQ(actual, expected);
}
class CustomCallClientAPITest : public ClientLibraryTestBase {};
// When using the client API, CustomCall targets can't begin with '$' -- these
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册