提交 4a2abacb 编写于 作者: P Peter Hawkins 提交者: TensorFlower Gardener

[XLA:Python] Add CustomCall support to Python LocalComputationBuilder.

PiperOrigin-RevId: 225205868
上级 46afcd06
......@@ -783,6 +783,21 @@ LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation,
return xla::Call(&builder_, local_computation.computation(), xla_ops);
}
LocalOp LocalComputationBuilder::CustomCall(
const string& call_target_name, absl::Span<const LocalOp> operands,
const Shape& shape_with_layout,
const std::vector<Shape>& operand_shapes_with_layout,
const string& opaque) {
std::vector<XlaOp> xla_ops;
xla_ops.reserve(operands.size());
for (const auto& op : operands) {
xla_ops.push_back(op.op());
}
return xla::CustomCallWithLayout(&builder_, call_target_name, xla_ops,
shape_with_layout,
operand_shapes_with_layout, opaque);
}
LocalOp LocalComputationBuilder::Transpose(
const LocalOp& operand, absl::Span<const int64> permutation) {
return xla::Transpose(operand.op(), permutation);
......
......@@ -352,6 +352,12 @@ class LocalComputationBuilder {
LocalOp Call(const LocalComputation& local_computation,
absl::Span<const LocalOp> operands);
LocalOp CustomCall(const string& call_target_name,
absl::Span<const LocalOp> operands,
const Shape& shape_with_layout,
const std::vector<Shape>& operand_shapes_with_layout,
const string& opaque);
LocalOp Transpose(const LocalOp& operand,
absl::Span<const int64> permutation);
......
......@@ -1147,6 +1147,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Cholesky;
%unignore xla::swig::LocalComputationBuilder::QR;
%unignore xla::swig::LocalComputationBuilder::TriangularSolve;
%unignore xla::swig::LocalComputationBuilder::CustomCall;
%unignore xla::swig::DeleteLocalComputation;
%unignore xla::swig::DestructureLocalShapedBufferTuple;
%unignore xla::swig::DestructureXrtAllocationTuple;
......
......@@ -1102,6 +1102,31 @@ class ComputationBuilder(object):
"""
return self._client.Call(computation_to_apply.computation, operands)
def CustomCall(self,
call_target_name,
operands,
shape_with_layout,
operand_shapes_with_layout,
opaque=None):
"""Enqueues a custom call operation onto the computation.
Args:
call_target_name: the name of the function to call.
operands: an iterable of LocalOp. The number and types of operands must
match the arity of `operand_shapes_with_layout`.
shape_with_layout: the shape of the operator's output, with layout.
operand_shapes_with_layout: the shapes of `operands`, including the
expected layouts.
opaque: an opaque string passed to the backend.
Returns:
A LocalOp representing the added custom call op.
"""
opaque = opaque or ''
return self._client.CustomCall(call_target_name, operands,
shape_with_layout,
operand_shapes_with_layout, opaque)
def Map(self, operands, computation_to_apply, dimensions):
"""Enqueues a map operation onto the computation.
......
......@@ -481,7 +481,9 @@ Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) {
const Shape& operand_shape_with_layout =
custom_call->operand_shapes_with_layout()[i];
TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(),
operand_shape_with_layout));
operand_shape_with_layout))
<< custom_call->operand(i)->shape().ToString() << " operand "
<< operand_shape_with_layout.ToString();
TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册