提交 22af085f 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

[XLA] add Iota and BroadcastedIota to local Python client

PiperOrigin-RevId: 225256432
上级 1b7e1c7c
......@@ -647,6 +647,15 @@ LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) {
return xla::ConstantLiteral(&builder_, literal);
}
LocalOp LocalComputationBuilder::Iota(PrimitiveType element_type, int64 size) {
return xla::Iota(&builder_, element_type, size);
}
LocalOp LocalComputationBuilder::BroadcastedIota(const Shape& shape,
int64 dimension) {
return xla::Iota(&builder_, shape, dimension);
}
LocalOp LocalComputationBuilder::Broadcast(
const LocalOp& operand, absl::Span<const int64> broadcast_sizes) {
return xla::Broadcast(operand.op(), broadcast_sizes);
......
......@@ -286,6 +286,10 @@ class LocalComputationBuilder {
LocalOp ConstantLiteral(const Literal& literal);
LocalOp Iota(PrimitiveType element_type, int64 size);
LocalOp BroadcastedIota(const Shape& shape, int64 dimension);
LocalOp Broadcast(const LocalOp& operand,
absl::Span<const int64> broadcast_sizes);
......
......@@ -1051,6 +1051,8 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Outfeed;
%unignore xla::swig::LocalComputationBuilder::ConstantLiteral;
%unignore xla::swig::LocalComputationBuilder::ConstantR0;
%unignore xla::swig::LocalComputationBuilder::Iota;
%unignore xla::swig::LocalComputationBuilder::BroadcastedIota;
%unignore xla::swig::LocalComputationBuilder::Broadcast;
%unignore xla::swig::LocalComputationBuilder::BroadcastInDim;
%unignore xla::swig::LocalComputationBuilder::Pad;
......
......@@ -831,6 +831,33 @@ class ComputationBuilder(object):
return self.ParameterWithShape(
Shape.from_pyval(value), name=name, parameter_num=parameter_num)
def Iota(self, dtype, size):
"""Enqueues an iota constant onto the computation.
Args:
dtype: expected numpy dtype of the output.
size: integer, the number of elements in the array.
Returns:
A LocalOp representing the added iota constant.
"""
element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))]
return self._client.Iota(element_type, size)
def BroadcastedIota(self, dtype, shape, dimension):
"""Enqueues a broadcasted iota constant onto the computation.
Args:
dtype: expected numpy dtype of the output.
shape: tuple of integers, the expected output shape (dimensions).
dimension: positive integer, dimension along which to increment values.
Returns:
A LocalOp representing the added broadcasted iota constant.
"""
xla_shape = Shape.array_shape(dtype, shape)
return self._client.BroadcastedIota(xla_shape, dimension)
def Broadcast(self, operand, sizes):
"""Enqueues a broadcast operation onto the computation.
......
......@@ -146,6 +146,17 @@ class ComputationsWithConstantsTest(LocalComputationTest):
c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.))
self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.])
def testIota(self):
c = self._NewComputation()
c.Iota(np.float32, 10)
self._ExecuteAndCompareExact(c, expected=np.arange(10, dtype=np.float32))
def testBroadcastedIota(self):
c = self._NewComputation()
c.BroadcastedIota(np.int64, (2, 3), 1)
expected = np.array([[0, 1, 2], [0, 1, 2]], dtype=np.int64)
self._ExecuteAndCompareExact(c, expected=expected)
def testBooleanAnd(self):
c = self._NewComputation()
c.And(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册