提交 45573542 编写于 作者: D David Majnemer 提交者: TensorFlower Gardener

[XLA] Add support for iota of PRED

PiperOrigin-RevId: 223413074
上级 cac33744
......@@ -2245,13 +2245,15 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
: iota->shape();
PrimitiveType component_element_type = component_shape.element_type();
llvm::Value* iota_result;
if (ShapeUtil::ElementIsIntegral(component_shape)) {
if (primitive_util::IsIntegralType(component_element_type) ||
component_element_type == PRED) {
iota_result = b_->CreateIntCast(
elem_index_linear,
llvm_ir::PrimitiveTypeToIrType(component_element_type, module_),
/*isSigned=*/false);
} else {
TF_RET_CHECK(ShapeUtil::ElementIsFloating(component_shape))
TF_RET_CHECK(
primitive_util::IsFloatingPointType(component_element_type))
<< component_element_type;
llvm::Type* float_ir_type;
if (component_element_type == BF16) {
......
......@@ -113,5 +113,26 @@ INSTANTIATE_TEST_CASE_P(IotaR3TestInstantiation, IotaR3Test,
/*step=*/10),
::testing::Values(0, 1, 2)));
class IotaR3PredTest : public ClientLibraryTestBase,
public ::testing::WithParamInterface<int> {};
TEST_P(IotaR3PredTest, DoIt) {
const auto element_type = PRED;
const int64 num_elements = 2;
const int64 iota_dim = GetParam();
XlaBuilder builder(TestName() + "_" + PrimitiveType_Name(element_type));
std::vector<int64> dimensions = {42, 19};
dimensions.insert(dimensions.begin() + iota_dim, num_elements);
Iota(&builder, ShapeUtil::MakeShape(element_type, dimensions), iota_dim);
if (primitive_util::IsFloatingPointType(element_type)) {
ComputeAndCompare(&builder, {}, ErrorSpec{0.0001});
} else {
ComputeAndCompare(&builder, {});
}
}
INSTANTIATE_TEST_CASE_P(IotaR3PredTestInstantiation, IotaR3PredTest,
::testing::Values(0, 1, 2));
} // namespace
} // namespace xla
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册