提交 b35e03a7 编写于 作者: P Prakalp Srivastava 提交者: TensorFlower Gardener

Add infeed_config attribute to InfeedOp.

PiperOrigin-RevId: 286224973
Change-Id: If77849b23b0ae49188df7ceb464908a8515b49ce
上级 8ae50866
......@@ -287,6 +287,12 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
llvm::ArrayRef<Value*>(operands.begin() + 2, operands.end()))
.getOperation();
}
case HloOpcode::kInfeed: {
attributes.push_back(builder_->getNamedAttr(
"infeed_config", mlir::StringAttr::get(instruction->infeed_config(),
builder_->getContext())));
MakeAndReturn(InfeedOp);
}
case HloOpcode::kPad: {
const auto& padding_config = instruction->padding_config();
llvm::SmallVector<int64_t, 4> edge_padding_low;
......@@ -448,7 +454,6 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
NoAttributeCase(kExp, ExpOp);
NoAttributeCase(kExpm1, Expm1Op);
NoAttributeCase(kFloor, FloorOp);
NoAttributeCase(kInfeed, InfeedOp);
NoAttributeCase(kImag, ImagOp);
NoAttributeCase(kLog, LogOp);
NoAttributeCase(kLog1p, Log1pOp);
......
......@@ -343,7 +343,10 @@ def HLO_InfeedOp : HLO_Op<"infeed", []> {
See https://www.tensorflow.org/xla/operation_semantics#infeed.
}];
let arguments = (ins HLO_Token:$token);
let arguments = (ins
HLO_Token:$token,
DefaultValuedAttr<StrAttr, "">:$infeed_config
);
let results = (outs HLO_Tuple);
let hasCustomHLOConverter = 1;
}
......
......@@ -558,8 +558,8 @@ LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) {
// The shape argument expected by the xla client API is the type of the first
// element in the result tuple.
auto result_type = op.getType().cast<mlir::TupleType>().getType(0);
value_map[op] = xla::InfeedWithToken(value_map[op.token()],
xla::TypeToShape(result_type));
value_map[op] = xla::InfeedWithToken(
value_map[op.token()], xla::TypeToShape(result_type), op.infeed_config());
return success();
}
......
......@@ -398,13 +398,13 @@ func @main(%arg0: tuple<tensor<f32>, tensor<i32>>) -> tensor<f32> {
// CHECK: HloModule
func @main(%arg0: !xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token> {
%0 = "xla_hlo.infeed"(%arg0) : (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token>
return %0 : tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token>
}
%0 = "xla_hlo.infeed"(%arg0) {infeed_config = "foobar"} : (!xla_hlo.token) -> tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token>
return %0 : tuple<tuple<tensor<3xi32>, tensor<i1>>, !xla_hlo.token>
}
// CHECK: ENTRY
// CHECK: [[ARG:%.*]] = token[] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = ((s32[3], pred[]), token[]) infeed(token[] [[ARG]])
// CHECK: ROOT %[[RESULT:.*]] = ((s32[3], pred[]), token[]) infeed(token[] [[ARG]]), infeed_config="foobar"
// -----
......
......@@ -364,6 +364,16 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
ROOT %imag.3 = f32[4] imag(c64[4] %Arg_0.1)
}
// CHECK-LABEL: func @test_infeed
// CHECK-SAME: ([[TOKEN:%.*]]: !xla_hlo.token) -> tuple<tensor<3xi32>, !xla_hlo.token> {
%test_infeed (token0: token[]) -> (s32[3], token[]) {
%token0 = token[] parameter(0)
// CHECK-NEXT: "xla_hlo.infeed"([[TOKEN]])
// CHECK-SAME: infeed_config = "foobar"
ROOT %infeed = (s32[3], token[]) infeed(token[] %token0), infeed_config="foobar"
}
// CHECK-LABEL: func @test_iota_1() -> tensor<4xf32> {
%test_iota_1 () -> f32[4] {
// CHECK-NEXT: "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册