提交 0bd78003 编写于 作者: S Sanjoy Das 提交者: Michael Case

[XLA:CPU] Generate correct IR for integer clamp

PiperOrigin-RevId: 184037078
上级 9f75f8e6
......@@ -1043,17 +1043,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
lhs_value, rhs_value, ir_builder_);
case HloOpcode::kMinimum:
return ir_builder_->CreateSelect(
ir_builder_->CreateICmp(
is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
lhs_value, rhs_value),
lhs_value, rhs_value);
return EmitIntegralMin(lhs_value, rhs_value, is_signed);
case HloOpcode::kMaximum:
return ir_builder_->CreateSelect(
ir_builder_->CreateICmp(
is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
lhs_value, rhs_value),
lhs_value, rhs_value);
return EmitIntegralMax(lhs_value, rhs_value, is_signed);
case HloOpcode::kAnd:
return ir_builder_->CreateAnd(lhs_value, rhs_value);
case HloOpcode::kOr:
......@@ -1070,6 +1062,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
}
}
// Emits IR that yields the larger of two integer values.
//
// An integer ">=" comparison of lhs_value against rhs_value is emitted and
// fed into a select that produces lhs_value when the comparison holds and
// rhs_value otherwise. `is_signed` picks the signed (SGE) or unsigned (UGE)
// form of the comparison.
llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
                                                 llvm::Value* rhs_value,
                                                 bool is_signed) const {
  const auto ge_predicate =
      is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE;
  llvm::Value* lhs_is_ge =
      ir_builder_->CreateICmp(ge_predicate, lhs_value, rhs_value);
  return ir_builder_->CreateSelect(lhs_is_ge, lhs_value, rhs_value);
}
// Emits IR that yields the smaller of two integer values.
//
// An integer "<=" comparison of lhs_value against rhs_value is emitted and
// fed into a select that produces lhs_value when the comparison holds and
// rhs_value otherwise. `is_signed` picks the signed (SLE) or unsigned (ULE)
// form of the comparison.
llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
                                                 llvm::Value* rhs_value,
                                                 bool is_signed) const {
  const auto le_predicate =
      is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE;
  llvm::Value* lhs_is_le =
      ir_builder_->CreateICmp(le_predicate, lhs_value, rhs_value);
  return ir_builder_->CreateSelect(lhs_is_le, lhs_value, rhs_value);
}
llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
int64 operand_no) const {
......@@ -1366,7 +1378,18 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
operand_to_generator.at(hlo->operand(2))(
ElementwiseSourceIndex(index, *hlo, 2)));
return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
PrimitiveType prim_type = hlo->shape().element_type();
if (primitive_util::IsFloatingPointType(prim_type)) {
return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
} else if (primitive_util::IsIntegralType(prim_type)) {
bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
return EmitIntegralMin(
max_value, EmitIntegralMax(min_value, arg_value, is_signed),
is_signed);
} else {
return Unimplemented("Clamp unimplemented for %s",
PrimitiveType_Name(prim_type).c_str());
}
};
case HloOpcode::kReducePrecision:
return [this, hlo, &operand_to_generator](
......
......@@ -86,6 +86,12 @@ class ElementalIrEmitter {
virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
llvm::Value* rhs_value) const;
// Emits an integer compare-and-select that produces lhs_value when
// lhs_value >= rhs_value and rhs_value otherwise. `is_signed` chooses
// between the signed and unsigned comparison predicate.
llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
bool is_signed) const;
// Emits an integer compare-and-select that produces lhs_value when
// lhs_value <= rhs_value and rhs_value otherwise. `is_signed` chooses
// between the signed and unsigned comparison predicate.
llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
bool is_signed) const;
virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type,
llvm::Value* value) const;
......
......@@ -1893,6 +1893,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
error_spec_);
}
// Element-wise Clamp over signed 32-bit vectors.
//
// Reading the three constants column by column as (lower, value, upper):
//   (1, 2, 3)    -> 2   value already inside the bounds
//   (-6, 10, 0)  -> 0   limited by the upper bound
//   (1, -5, 25)  -> 1   raised to the lower bound (negative operand)
//   (2, 1, 5)    -> 2   raised to the lower bound
//   (0, 4, 123)  -> 4   value already inside the bounds
//   (-5, 10, -1) -> -1  limited by a negative upper bound
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) {
  ComputationBuilder builder(client_, TestName());
  auto lower_bounds = builder.ConstantR1<int32>({1, -6, 1, 2, 0, -5});
  auto operands = builder.ConstantR1<int32>({2, 10, -5, 1, 4, 10});
  auto upper_bounds = builder.ConstantR1<int32>({3, 0, 25, 5, 123, -1});
  auto clamped = builder.Clamp(lower_bounds, operands, upper_bounds);

  ComputeAndCompareR1<int32>(&builder, {2, 0, 1, 2, 4, -1}, {});
}
// Element-wise Clamp over unsigned 32-bit vectors.
//
// Reading the three constants column by column as (lower, value, upper):
//   (1, 2, 3)           -> 2        value already inside the bounds
//   (2, 10, 5)          -> 5        limited by the upper bound
//   (1, 5, 25)          -> 5        value already inside the bounds
//   (2, 1, 5)           -> 2        raised to the lower bound
//   (0, 4, 123)         -> 4        value already inside the bounds
//   (~0u - 4, 10, ~0u)  -> ~0u - 4  raised to a lower bound near UINT32_MAX
// The final column uses bounds whose top bit is set, so its expected result
// only holds when the bounds are compared as unsigned values.
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) {
  ComputationBuilder builder(client_, TestName());
  auto lower_bounds = builder.ConstantR1<uint32>({1, 2, 1, 2, 0, ~0u - 4});
  auto operands = builder.ConstantR1<uint32>({2, 10, 5, 1, 4, 10});
  auto upper_bounds = builder.ConstantR1<uint32>({3, 5, 25, 5, 123, ~0u});
  auto clamped = builder.Clamp(lower_bounds, operands, upper_bounds);

  ComputeAndCompareR1<uint32>(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
ComputationBuilder builder(client_, TestName());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册