Commit 42bab294 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

NFC: Clean up XLA op definitions

Avoid diamond inheritance in tablegen caused by classes that inherit from Op and also from Results and Arguments.
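As a minimal sketch of the pattern (the class names here are hypothetical, not the ones in this change): instead of mixing Op together with the Arguments and Results helper classes, a helper class now derives from Op alone and assigns the fields directly in its body.

    // Before: the helper reaches Op's fields along two paths, once via
    // XLA_Op and once via the Arguments/Results mixins (a tablegen diamond).
    class Sketch_UnaryOpOld<string mnemonic, list<OpTrait> traits> :
        XLA_Op<mnemonic, traits>,
        Arguments<(ins XLA_Tensor:$operand)>,
        Results<(outs XLA_Tensor:$res)>;

    // After: derive from Op only and set the fields in the body.
    class Sketch_UnaryOpNew<string mnemonic, list<OpTrait> traits> :
        XLA_Op<mnemonic, traits> {
      let arguments = (ins XLA_Tensor);
      let results = (outs XLA_Tensor);
    }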
Remove extraneous Operand and Result names in cases where there's only one. There will already be a generated getOperand/getResult method. Having an additional "res" accessor doesn't add value.
In places where naming a single result is necessary for traits that check directly against attribute names, the single result is named "result".
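For example (hypothetical op names, mirroring the DynamicUpdateSliceOp change below): a trait such as AllElementTypesMatch looks operands and results up by name, so an op using it must name its single result "result"; everywhere else the bare form suffices, since ODS already generates getResult().

    // The trait refers to "operand" and "result" by name, so the single
    // result must carry the name the trait mentions.
    def Sketch_UpdateOp : XLA_Op<"sketch-update",
        [NoSideEffect, AllElementTypesMatch<["operand", "result"]>]> {
      let arguments = (ins XLA_Tensor:$operand);
      let results = (outs XLA_Tensor:$result);
    }

    // No trait needs the name: leave the single result anonymous and use
    // the generated getResult() accessor instead of an extra res().
    def Sketch_PlainOp : XLA_Op<"sketch-plain", [NoSideEffect]> {
      let arguments = (ins XLA_Tensor);
      let results = (outs XLA_Tensor);
    }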
Avoid non-elementwise ops inheriting from elementwise subclasses.
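Concretely (a sketch with a hypothetical op, not one from this change): an op like slice takes a single tensor but is not element-wise, so rather than inheriting its signature from XLA_UnaryElementwiseOp it now derives from XLA_Op and spells out its own arguments and results, as XLA_SliceOp and XLA_CopyOp do in the diff below.

    // A non-elementwise op declares its own signature instead of
    // borrowing it from an elementwise helper class.
    def Sketch_SliceLikeOp : XLA_Op<"sketch-slice", [NoSideEffect]> {
      let arguments = (ins
        XLA_Tensor:$operand,
        ElementsAttr:$start_indices
      );
      let results = (outs XLA_Tensor);
    }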

PiperOrigin-RevId: 262642541
Parent 5293302f
@@ -111,8 +111,11 @@ def XLA_IotaOp : XLA_Op<"iota", [NoSideEffect]> {
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
class XLA_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits>:
-XLA_Op<mnemonic, traits>, Arguments<(ins XLA_Tensor:$operand)>,
-Results<(outs XLA_Tensor:$res)>;
+XLA_Op<mnemonic, traits> {
+let arguments = (ins XLA_Tensor);
+let results = (outs XLA_Tensor);
+}
def XLA_AbsOp: XLA_UnaryElementwiseOp<"abs", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Absolute value operator";
@@ -196,15 +199,14 @@ def XLA_TanhOp: XLA_UnaryElementwiseOp<"tanh",
def BroadcastDimAttr : OptionalAttr<ElementsAttr>;
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
-class XLA_BinaryElementwiseOp<string mnemonic,
-list<OpTrait> traits, dag args = (ins)> :
-XLA_Op<mnemonic, traits>,
-Arguments<(
-ins XLA_Tensor:$lhs,
-XLA_Tensor:$rhs,
-BroadcastDimAttr:$broadcast_dimensions
-)>,
-Results<(outs XLA_Tensor:$res)> {
+class XLA_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
+XLA_Op<mnemonic, traits> {
+let arguments = (ins
+XLA_Tensor:$lhs,
+XLA_Tensor:$rhs,
+BroadcastDimAttr:$broadcast_dimensions
+);
+let results = (outs XLA_Tensor);
let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
}
@@ -302,7 +304,7 @@ def XLA_WhileOp: XLA_Op<"while", [NoSideEffect, SameOperandsAndResultType]> {
SymbolRefAttr:$body
);
-let results = (outs Variadic<XLA_TensorOrTuple>:$res);
+let results = (outs Variadic<XLA_TensorOrTuple>);
// TODO(b/129422361): WhileOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
@@ -324,7 +326,7 @@ def XLA_ReduceOp: XLA_Op<"reduce", [NoSideEffect]> {
ElementsAttr:$dimensions
);
-let results = (outs Variadic<XLA_Tensor>:$res);
+let results = (outs Variadic<XLA_Tensor>);
// TODO(b/129422361): ReduceOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
@@ -363,7 +365,7 @@ def XLA_TupleOp : XLA_Op<"tuple", [NoSideEffect]> {
}];
let arguments = (ins Variadic<XLA_TensorOrTuple>:$val);
-let results = (outs XLA_Tuple:$res);
+let results = (outs XLA_Tuple);
// TupleOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
@@ -418,7 +420,7 @@ def XLA_CompareOp: XLA_Op<"compare",
BroadcastDimAttr:$broadcast_dimensions,
XLA_ComparisonDirectionAttr:$comparison_direction
);
-let results = (outs XLA_PredTensor:$res);
+let results = (outs XLA_PredTensor);
let summary = "Comparison operator";
let description = [{
@@ -433,7 +435,7 @@ def XLA_CompareOp: XLA_Op<"compare",
// XLA Slice definitions.
//===----------------------------------------------------------------------===//
-def XLA_SliceOp: XLA_UnaryElementwiseOp<
+def XLA_SliceOp: XLA_Op<
"slice",
[NoSideEffect, SameOperandsAndResultElementType,
AllTypesMatch<["start_indices", "limit_indices"]>]> {
@@ -443,7 +445,7 @@ def XLA_SliceOp: XLA_UnaryElementwiseOp<
ElementsAttr:$limit_indices
);
-let results = (outs XLA_Tensor:$res);
+let results = (outs XLA_Tensor);
let summary = "Slice operator";
@@ -458,15 +460,15 @@ def XLA_SliceOp: XLA_UnaryElementwiseOp<
let hasCustomHLOConverter = 1;
}
-def XLA_DynamicUpdateSliceOp: XLA_UnaryElementwiseOp<"dynamic-update-slice",
-[NoSideEffect, AllElementTypesMatch<["operand", "res"]>]> {
+def XLA_DynamicUpdateSliceOp: XLA_Op<"dynamic-update-slice",
+[NoSideEffect, AllElementTypesMatch<["operand", "result"]>]> {
let arguments = (ins
XLA_Tensor:$operand,
XLA_Tensor:$update,
Variadic<XLA_Tensor>:$start_indices
);
-let results = (outs XLA_Tensor:$res);
+let results = (outs XLA_Tensor:$result);
let summary = "Dynamic Update Slice operator";
@@ -505,9 +507,7 @@ def XLA_BatchNormInferenceOp : XLA_Op<"batch_norm_inference", [NoSideEffect]> {
I64Attr:$feature_index
);
-let results = (outs
-XLA_Tensor:$res
-);
+let results = (outs XLA_Tensor);
}
def XLA_BroadcastOp : XLA_Op<"broadcast",
@@ -531,7 +531,7 @@ def XLA_BroadcastOp : XLA_Op<"broadcast",
ElementsAttr:$broadcast_sizes
);
-let results = (outs XLA_Tensor:$res);
+let results = (outs XLA_Tensor);
// TODO(b/129012527) These should be expressed as type constraints.
let verifier = [{
@@ -548,7 +548,7 @@ def XLA_BroadcastOp : XLA_Op<"broadcast",
"broadcast_sizes has rank {0} instead of rank 1", sizesRank));
}
-auto resultType = res()->getType().cast<RankedTensorType>();
+auto resultType = getResult()->getType().cast<RankedTensorType>();
auto resultRank = resultType.getRank();
auto operandType = operand()->getType().cast<RankedTensorType>();
auto operandRank = operandType.getRank();
@@ -613,7 +613,7 @@ def XLA_BroadcastInDimOp : XLA_Op<"broadcast_in_dim",
BroadcastDimAttr:$broadcast_dimensions
);
-let results = (outs XLA_Tensor:$res);
+let results = (outs XLA_Tensor);
// TODO(b/129012527) These should be expressed as type constraints.
let verifier = [{
@@ -651,7 +651,7 @@ def XLA_BroadcastInDimOp : XLA_Op<"broadcast_in_dim",
dimensionsSize, operandRank));
}
-auto resultType = res()->getType().cast<RankedTensorType>();
+auto resultType = getResult()->getType().cast<RankedTensorType>();
auto resultRank = resultType.getRank();
if (resultRank < operandRank) {
return emitOpError(
@@ -706,9 +706,7 @@ def XLA_ClampOp : XLA_Op<"clamp",
XLA_Tensor:$max
);
-let results = (outs
-XLA_Tensor:$res
-);
+let results = (outs XLA_Tensor);
// TODO(b/129012527) These should be expressed as type constraints.
let verifier = [{
@@ -783,7 +781,7 @@ def XLA_ConcatenateOp : XLA_Op<"concatenate",
return success();
}];
-let results = (outs XLA_Tensor:$res);
+let results = (outs XLA_Tensor);
// TODO(b/129422361) ConcatOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
@@ -803,20 +801,23 @@ def XLA_ConvOp : XLA_Op<"conv", [NoSideEffect]> {
XLA_Tensor:$rhs
);
-let results = (outs XLA_Tensor:$res);
+let results = (outs XLA_Tensor);
// TODO(b/129422361) Needs additional work to handle attributes.
// Conv has custom handling because its other args are passed as attributes
let hasCustomHLOConverter = 1;
}
-def XLA_CopyOp: XLA_UnaryElementwiseOp<"copy", [NoSideEffect, SameOperandsAndResultType]> {
+def XLA_CopyOp: XLA_Op<"copy", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Copy operator";
let description = [{
Returns a copy of `operand`.
}];
+let arguments = (ins XLA_Tensor);
+let results = (outs XLA_Tensor);
// TODO(b/129422361) Implement special handling.
// Copy has an HloOpcode, but is not one of the ops defined in xla_builder.
let hasCustomHLOConverter = 1;
@@ -828,7 +829,7 @@ def XLA_DotOp: XLA_Op<"dot", [NoSideEffect]> {
XLA_Tensor:$rhs,
XLA_PrecisionConfigAttr:$precision_config
);
-let results = (outs XLA_Tensor:$res);
+let results = (outs XLA_Tensor);
let description = [{
Performs dot products between vectors, vector/matrix and matrix/matrix
@@ -849,7 +850,7 @@ def XLA_GatherOp: XLA_Op<"gather", [NoSideEffect]> {
ElementsAttr: $start_index_map
);
-let results = (outs XLA_Tensor:$res);
+let results = (outs XLA_Tensor);
let summary = "Gather operator";
@@ -868,7 +869,7 @@ def XLA_ReshapeOp: XLA_Op<"reshape",
[NoSideEffect, SameOperandsAndResultElementType]> {
let arguments = (ins XLA_Tensor:$operand);
-let results = (outs XLA_Tensor:$res);
+let results = (outs XLA_Tensor);
let summary = "Reshape operator";
@@ -909,7 +910,7 @@ def XLA_SelectOp: XLA_Op<"select", [NoSideEffect]> {
XLA_Tensor:$on_false
);
-let results = (outs XLA_Tensor:$res);
+let results = (outs XLA_Tensor);
// TODO(b/129012527) These should be expressed as type constraints.
let verifier = [{
@@ -956,7 +957,7 @@ def XLA_ReverseOp: XLA_Op<"reverse",
ElementsAttr:$dimensions
);
-let results = (outs XLA_Tensor:$res);
+let results = (outs XLA_Tensor);
// TODO(b/129422361): ReverseOp has a custom constructor for HLO.
let hasCustomHLOConverter = 1;
@@ -981,7 +982,7 @@ def XLA_PadOp: XLA_Op<"pad",
ElementsAttr: $interior_padding
);
-let results = (outs XLA_Tensor:$res);
+let results = (outs XLA_Tensor);
let description = [{
Pads the `operand` according to TBD.
@@ -1052,7 +1053,7 @@ def XLA_TransposeOp: XLA_Op<"transpose",
XLA_Tensor:$operand,
ElementsAttr:$permutation
);
-let results = (outs XLA_Tensor:$res);
+let results = (outs XLA_Tensor);
// TODO(b/129012527) These should be expressed as type constraints.
let verifier = [{
@@ -1078,7 +1079,7 @@ def XLA_TransposeOp: XLA_Op<"transpose",
permutationSize, operandRank));
}
-auto resultType = res()->getType().cast<RankedTensorType>();
+auto resultType = getResult()->getType().cast<RankedTensorType>();
auto resultRank = resultType.getRank();
if (resultRank != operandRank) {
return emitOpError(
......