提交 0bd78003 编写于 作者: S Sanjoy Das 提交者: Michael Case

[XLA:CPU] Generate correct IR for integer clamp

PiperOrigin-RevId: 184037078
上级 9f75f8e6
...@@ -1043,17 +1043,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp( ...@@ -1043,17 +1043,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE, is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
lhs_value, rhs_value, ir_builder_); lhs_value, rhs_value, ir_builder_);
case HloOpcode::kMinimum: case HloOpcode::kMinimum:
return ir_builder_->CreateSelect( return EmitIntegralMin(lhs_value, rhs_value, is_signed);
ir_builder_->CreateICmp(
is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
lhs_value, rhs_value),
lhs_value, rhs_value);
case HloOpcode::kMaximum: case HloOpcode::kMaximum:
return ir_builder_->CreateSelect( return EmitIntegralMax(lhs_value, rhs_value, is_signed);
ir_builder_->CreateICmp(
is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
lhs_value, rhs_value),
lhs_value, rhs_value);
case HloOpcode::kAnd: case HloOpcode::kAnd:
return ir_builder_->CreateAnd(lhs_value, rhs_value); return ir_builder_->CreateAnd(lhs_value, rhs_value);
case HloOpcode::kOr: case HloOpcode::kOr:
...@@ -1070,6 +1062,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp( ...@@ -1070,6 +1062,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
} }
} }
llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
llvm::Value* rhs_value,
bool is_signed) const {
return ir_builder_->CreateSelect(
ir_builder_->CreateICmp(
is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
lhs_value, rhs_value),
lhs_value, rhs_value);
}
llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
llvm::Value* rhs_value,
bool is_signed) const {
return ir_builder_->CreateSelect(
ir_builder_->CreateICmp(
is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
lhs_value, rhs_value),
lhs_value, rhs_value);
}
llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
int64 operand_no) const { int64 operand_no) const {
...@@ -1366,7 +1378,18 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( ...@@ -1366,7 +1378,18 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
TF_ASSIGN_OR_RETURN(llvm::Value * max_value, TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
operand_to_generator.at(hlo->operand(2))( operand_to_generator.at(hlo->operand(2))(
ElementwiseSourceIndex(index, *hlo, 2))); ElementwiseSourceIndex(index, *hlo, 2)));
return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); PrimitiveType prim_type = hlo->shape().element_type();
if (primitive_util::IsFloatingPointType(prim_type)) {
return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
} else if (primitive_util::IsIntegralType(prim_type)) {
bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
return EmitIntegralMin(
max_value, EmitIntegralMax(min_value, arg_value, is_signed),
is_signed);
} else {
return Unimplemented("Clamp unimplemented for %s",
PrimitiveType_Name(prim_type).c_str());
}
}; };
case HloOpcode::kReducePrecision: case HloOpcode::kReducePrecision:
return [this, hlo, &operand_to_generator]( return [this, hlo, &operand_to_generator](
......
...@@ -86,6 +86,12 @@ class ElementalIrEmitter { ...@@ -86,6 +86,12 @@ class ElementalIrEmitter {
virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value, virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
llvm::Value* rhs_value) const; llvm::Value* rhs_value) const;
llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
bool is_signed) const;
llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
bool is_signed) const;
virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type, virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type,
llvm::Value* value) const; llvm::Value* value) const;
......
...@@ -1893,6 +1893,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) { ...@@ -1893,6 +1893,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
error_spec_); error_spec_);
} }
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) {
ComputationBuilder builder(client_, TestName());
auto min_vector = builder.ConstantR1<int32>({1, -6, 1, 2, 0, -5});
auto arg_vector = builder.ConstantR1<int32>({2, 10, -5, 1, 4, 10});
auto max_vector = builder.ConstantR1<int32>({3, 0, 25, 5, 123, -1});
auto clamp = builder.Clamp(min_vector, arg_vector, max_vector);
ComputeAndCompareR1<int32>(&builder, {2, 0, 1, 2, 4, -1}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) {
ComputationBuilder builder(client_, TestName());
auto min_vector = builder.ConstantR1<uint32>({1, 2, 1, 2, 0, ~0u - 4});
auto arg_vector = builder.ConstantR1<uint32>({2, 10, 5, 1, 4, 10});
auto max_vector = builder.ConstantR1<uint32>({3, 5, 25, 5, 123, ~0u});
auto clamp = builder.Clamp(min_vector, arg_vector, max_vector);
ComputeAndCompareR1<uint32>(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
ComputationBuilder builder(client_, TestName()); ComputationBuilder builder(client_, TestName());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册