提交 42bab294 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

NFC: Clean up XLA op definitions

Avoid diamond inheritance in tablegen with classes that inherit from Op and also inherit from Results and Arguments.
Remove extraneous Operand and Result names in cases where there's only one. There will already be a generated getOperand/getResult method. Having an additional "res" accessor doesn't add value.
In places where naming a single result is necessary for traits that check directly against attribute names, named single results "result".
Avoid non-elementwise ops inheriting from elementwise subclasses.

PiperOrigin-RevId: 262642541
上级 5293302f
...@@ -111,8 +111,11 @@ def XLA_IotaOp : XLA_Op<"iota", [NoSideEffect]> { ...@@ -111,8 +111,11 @@ def XLA_IotaOp : XLA_Op<"iota", [NoSideEffect]> {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
class XLA_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits>: class XLA_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits>:
XLA_Op<mnemonic, traits>, Arguments<(ins XLA_Tensor:$operand)>, XLA_Op<mnemonic, traits> {
Results<(outs XLA_Tensor:$res)>;
let arguments = (ins XLA_Tensor);
let results = (outs XLA_Tensor);
}
def XLA_AbsOp: XLA_UnaryElementwiseOp<"abs", [NoSideEffect, SameOperandsAndResultType]> { def XLA_AbsOp: XLA_UnaryElementwiseOp<"abs", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Absolute value operator"; let summary = "Absolute value operator";
...@@ -196,15 +199,14 @@ def XLA_TanhOp: XLA_UnaryElementwiseOp<"tanh", ...@@ -196,15 +199,14 @@ def XLA_TanhOp: XLA_UnaryElementwiseOp<"tanh",
def BroadcastDimAttr : OptionalAttr<ElementsAttr>; def BroadcastDimAttr : OptionalAttr<ElementsAttr>;
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
class XLA_BinaryElementwiseOp<string mnemonic, class XLA_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
list<OpTrait> traits, dag args = (ins)> : XLA_Op<mnemonic, traits> {
XLA_Op<mnemonic, traits>, let arguments = (ins
Arguments<( XLA_Tensor:$lhs,
ins XLA_Tensor:$lhs, XLA_Tensor:$rhs,
XLA_Tensor:$rhs, BroadcastDimAttr:$broadcast_dimensions
BroadcastDimAttr:$broadcast_dimensions );
)>, let results = (outs XLA_Tensor);
Results<(outs XLA_Tensor:$res)> {
let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }]; let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }];
let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }]; let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }];
} }
...@@ -302,7 +304,7 @@ def XLA_WhileOp: XLA_Op<"while", [NoSideEffect, SameOperandsAndResultType]> { ...@@ -302,7 +304,7 @@ def XLA_WhileOp: XLA_Op<"while", [NoSideEffect, SameOperandsAndResultType]> {
SymbolRefAttr:$body SymbolRefAttr:$body
); );
let results = (outs Variadic<XLA_TensorOrTuple>:$res); let results = (outs Variadic<XLA_TensorOrTuple>);
// TODO(b/129422361): WhileOp has special conversion logic to HLO. // TODO(b/129422361): WhileOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1; let hasCustomHLOConverter = 1;
...@@ -324,7 +326,7 @@ def XLA_ReduceOp: XLA_Op<"reduce", [NoSideEffect]> { ...@@ -324,7 +326,7 @@ def XLA_ReduceOp: XLA_Op<"reduce", [NoSideEffect]> {
ElementsAttr:$dimensions ElementsAttr:$dimensions
); );
let results = (outs Variadic<XLA_Tensor>:$res); let results = (outs Variadic<XLA_Tensor>);
// TODO(b/129422361): ReduceOp has special conversion logic to HLO. // TODO(b/129422361): ReduceOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1; let hasCustomHLOConverter = 1;
...@@ -363,7 +365,7 @@ def XLA_TupleOp : XLA_Op<"tuple", [NoSideEffect]> { ...@@ -363,7 +365,7 @@ def XLA_TupleOp : XLA_Op<"tuple", [NoSideEffect]> {
}]; }];
let arguments = (ins Variadic<XLA_TensorOrTuple>:$val); let arguments = (ins Variadic<XLA_TensorOrTuple>:$val);
let results = (outs XLA_Tuple:$res); let results = (outs XLA_Tuple);
// TupleOp has special conversion logic to HLO. // TupleOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1; let hasCustomHLOConverter = 1;
...@@ -418,7 +420,7 @@ def XLA_CompareOp: XLA_Op<"compare", ...@@ -418,7 +420,7 @@ def XLA_CompareOp: XLA_Op<"compare",
BroadcastDimAttr:$broadcast_dimensions, BroadcastDimAttr:$broadcast_dimensions,
XLA_ComparisonDirectionAttr:$comparison_direction XLA_ComparisonDirectionAttr:$comparison_direction
); );
let results = (outs XLA_PredTensor:$res); let results = (outs XLA_PredTensor);
let summary = "Comparison operator"; let summary = "Comparison operator";
let description = [{ let description = [{
...@@ -433,7 +435,7 @@ def XLA_CompareOp: XLA_Op<"compare", ...@@ -433,7 +435,7 @@ def XLA_CompareOp: XLA_Op<"compare",
// XLA Slice definitions. // XLA Slice definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def XLA_SliceOp: XLA_UnaryElementwiseOp< def XLA_SliceOp: XLA_Op<
"slice", "slice",
[NoSideEffect, SameOperandsAndResultElementType, [NoSideEffect, SameOperandsAndResultElementType,
AllTypesMatch<["start_indices", "limit_indices"]>]> { AllTypesMatch<["start_indices", "limit_indices"]>]> {
...@@ -443,7 +445,7 @@ def XLA_SliceOp: XLA_UnaryElementwiseOp< ...@@ -443,7 +445,7 @@ def XLA_SliceOp: XLA_UnaryElementwiseOp<
ElementsAttr:$limit_indices ElementsAttr:$limit_indices
); );
let results = (outs XLA_Tensor:$res); let results = (outs XLA_Tensor);
let summary = "Slice operator"; let summary = "Slice operator";
...@@ -458,15 +460,15 @@ def XLA_SliceOp: XLA_UnaryElementwiseOp< ...@@ -458,15 +460,15 @@ def XLA_SliceOp: XLA_UnaryElementwiseOp<
let hasCustomHLOConverter = 1; let hasCustomHLOConverter = 1;
} }
def XLA_DynamicUpdateSliceOp: XLA_UnaryElementwiseOp<"dynamic-update-slice", def XLA_DynamicUpdateSliceOp: XLA_Op<"dynamic-update-slice",
[NoSideEffect, AllElementTypesMatch<["operand", "res"]>]> { [NoSideEffect, AllElementTypesMatch<["operand", "result"]>]> {
let arguments = (ins let arguments = (ins
XLA_Tensor:$operand, XLA_Tensor:$operand,
XLA_Tensor:$update, XLA_Tensor:$update,
Variadic<XLA_Tensor>:$start_indices Variadic<XLA_Tensor>:$start_indices
); );
let results = (outs XLA_Tensor:$res); let results = (outs XLA_Tensor:$result);
let summary = "Dynamic Update Slice operator"; let summary = "Dynamic Update Slice operator";
...@@ -505,9 +507,7 @@ def XLA_BatchNormInferenceOp : XLA_Op<"batch_norm_inference", [NoSideEffect]> { ...@@ -505,9 +507,7 @@ def XLA_BatchNormInferenceOp : XLA_Op<"batch_norm_inference", [NoSideEffect]> {
I64Attr:$feature_index I64Attr:$feature_index
); );
let results = (outs let results = (outs XLA_Tensor);
XLA_Tensor:$res
);
} }
def XLA_BroadcastOp : XLA_Op<"broadcast", def XLA_BroadcastOp : XLA_Op<"broadcast",
...@@ -531,7 +531,7 @@ def XLA_BroadcastOp : XLA_Op<"broadcast", ...@@ -531,7 +531,7 @@ def XLA_BroadcastOp : XLA_Op<"broadcast",
ElementsAttr:$broadcast_sizes ElementsAttr:$broadcast_sizes
); );
let results = (outs XLA_Tensor:$res); let results = (outs XLA_Tensor);
// TODO(b/129012527) These should be expressed as type constraints. // TODO(b/129012527) These should be expressed as type constraints.
let verifier = [{ let verifier = [{
...@@ -548,7 +548,7 @@ def XLA_BroadcastOp : XLA_Op<"broadcast", ...@@ -548,7 +548,7 @@ def XLA_BroadcastOp : XLA_Op<"broadcast",
"broadcast_sizes has rank {0} instead of rank 1", sizesRank)); "broadcast_sizes has rank {0} instead of rank 1", sizesRank));
} }
auto resultType = res()->getType().cast<RankedTensorType>(); auto resultType = getResult()->getType().cast<RankedTensorType>();
auto resultRank = resultType.getRank(); auto resultRank = resultType.getRank();
auto operandType = operand()->getType().cast<RankedTensorType>(); auto operandType = operand()->getType().cast<RankedTensorType>();
auto operandRank = operandType.getRank(); auto operandRank = operandType.getRank();
...@@ -613,7 +613,7 @@ def XLA_BroadcastInDimOp : XLA_Op<"broadcast_in_dim", ...@@ -613,7 +613,7 @@ def XLA_BroadcastInDimOp : XLA_Op<"broadcast_in_dim",
BroadcastDimAttr:$broadcast_dimensions BroadcastDimAttr:$broadcast_dimensions
); );
let results = (outs XLA_Tensor:$res); let results = (outs XLA_Tensor);
// TODO(b/129012527) These should be expressed as type constraints. // TODO(b/129012527) These should be expressed as type constraints.
let verifier = [{ let verifier = [{
...@@ -651,7 +651,7 @@ def XLA_BroadcastInDimOp : XLA_Op<"broadcast_in_dim", ...@@ -651,7 +651,7 @@ def XLA_BroadcastInDimOp : XLA_Op<"broadcast_in_dim",
dimensionsSize, operandRank)); dimensionsSize, operandRank));
} }
auto resultType = res()->getType().cast<RankedTensorType>(); auto resultType = getResult()->getType().cast<RankedTensorType>();
auto resultRank = resultType.getRank(); auto resultRank = resultType.getRank();
if (resultRank < operandRank) { if (resultRank < operandRank) {
return emitOpError( return emitOpError(
...@@ -706,9 +706,7 @@ def XLA_ClampOp : XLA_Op<"clamp", ...@@ -706,9 +706,7 @@ def XLA_ClampOp : XLA_Op<"clamp",
XLA_Tensor:$max XLA_Tensor:$max
); );
let results = (outs let results = (outs XLA_Tensor);
XLA_Tensor:$res
);
// TODO(b/129012527) These should be expressed as type constraints. // TODO(b/129012527) These should be expressed as type constraints.
let verifier = [{ let verifier = [{
...@@ -783,7 +781,7 @@ def XLA_ConcatenateOp : XLA_Op<"concatenate", ...@@ -783,7 +781,7 @@ def XLA_ConcatenateOp : XLA_Op<"concatenate",
return success(); return success();
}]; }];
let results = (outs XLA_Tensor:$res); let results = (outs XLA_Tensor);
// TODO(b/129422361) ConcatOp has special conversion logic to HLO. // TODO(b/129422361) ConcatOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1; let hasCustomHLOConverter = 1;
...@@ -803,20 +801,23 @@ def XLA_ConvOp : XLA_Op<"conv", [NoSideEffect]> { ...@@ -803,20 +801,23 @@ def XLA_ConvOp : XLA_Op<"conv", [NoSideEffect]> {
XLA_Tensor:$rhs XLA_Tensor:$rhs
); );
let results = (outs XLA_Tensor:$res); let results = (outs XLA_Tensor);
// TODO(b/129422361) Needs additional work to handle attributes. // TODO(b/129422361) Needs additional work to handle attributes.
// Conv has custom handling because its other args are passed as attributes // Conv has custom handling because its other args are passed as attributes
let hasCustomHLOConverter = 1; let hasCustomHLOConverter = 1;
} }
def XLA_CopyOp: XLA_UnaryElementwiseOp<"copy", [NoSideEffect, SameOperandsAndResultType]> { def XLA_CopyOp: XLA_Op<"copy", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Copy operator"; let summary = "Copy operator";
let description = [{ let description = [{
Returns a copy of `operand`. Returns a copy of `operand`.
}]; }];
let arguments = (ins XLA_Tensor);
let results = (outs XLA_Tensor);
// TODO(b/129422361) Implement special handling. // TODO(b/129422361) Implement special handling.
// Copy has an HloOpcode, but is not one of the ops defined in xla_builder. // Copy has an HloOpcode, but is not one of the ops defined in xla_builder.
let hasCustomHLOConverter = 1; let hasCustomHLOConverter = 1;
...@@ -828,7 +829,7 @@ def XLA_DotOp: XLA_Op<"dot", [NoSideEffect]> { ...@@ -828,7 +829,7 @@ def XLA_DotOp: XLA_Op<"dot", [NoSideEffect]> {
XLA_Tensor:$rhs, XLA_Tensor:$rhs,
XLA_PrecisionConfigAttr:$precision_config XLA_PrecisionConfigAttr:$precision_config
); );
let results = (outs XLA_Tensor:$res); let results = (outs XLA_Tensor);
let description = [{ let description = [{
Performs dot products between vectors, vector/matrix and matrix/matrix Performs dot products between vectors, vector/matrix and matrix/matrix
...@@ -849,7 +850,7 @@ def XLA_GatherOp: XLA_Op<"gather", [NoSideEffect]> { ...@@ -849,7 +850,7 @@ def XLA_GatherOp: XLA_Op<"gather", [NoSideEffect]> {
ElementsAttr: $start_index_map ElementsAttr: $start_index_map
); );
let results = (outs XLA_Tensor:$res); let results = (outs XLA_Tensor);
let summary = "Gather operator"; let summary = "Gather operator";
...@@ -868,7 +869,7 @@ def XLA_ReshapeOp: XLA_Op<"reshape", ...@@ -868,7 +869,7 @@ def XLA_ReshapeOp: XLA_Op<"reshape",
[NoSideEffect, SameOperandsAndResultElementType]> { [NoSideEffect, SameOperandsAndResultElementType]> {
let arguments = (ins XLA_Tensor:$operand); let arguments = (ins XLA_Tensor:$operand);
let results = (outs XLA_Tensor:$res); let results = (outs XLA_Tensor);
let summary = "Reshape operator"; let summary = "Reshape operator";
...@@ -909,7 +910,7 @@ def XLA_SelectOp: XLA_Op<"select", [NoSideEffect]> { ...@@ -909,7 +910,7 @@ def XLA_SelectOp: XLA_Op<"select", [NoSideEffect]> {
XLA_Tensor:$on_false XLA_Tensor:$on_false
); );
let results = (outs XLA_Tensor:$res); let results = (outs XLA_Tensor);
// TODO(b/129012527) These should be expressed as type constraints. // TODO(b/129012527) These should be expressed as type constraints.
let verifier = [{ let verifier = [{
...@@ -956,7 +957,7 @@ def XLA_ReverseOp: XLA_Op<"reverse", ...@@ -956,7 +957,7 @@ def XLA_ReverseOp: XLA_Op<"reverse",
ElementsAttr:$dimensions ElementsAttr:$dimensions
); );
let results = (outs XLA_Tensor:$res); let results = (outs XLA_Tensor);
// TODO(b/129422361): ReverseOp has a custom constructor for HLO. // TODO(b/129422361): ReverseOp has a custom constructor for HLO.
let hasCustomHLOConverter = 1; let hasCustomHLOConverter = 1;
...@@ -981,7 +982,7 @@ def XLA_PadOp: XLA_Op<"pad", ...@@ -981,7 +982,7 @@ def XLA_PadOp: XLA_Op<"pad",
ElementsAttr: $interior_padding ElementsAttr: $interior_padding
); );
let results = (outs XLA_Tensor:$res); let results = (outs XLA_Tensor);
let description = [{ let description = [{
Pads the `operand` according to TBD. Pads the `operand` according to TBD.
...@@ -1052,7 +1053,7 @@ def XLA_TransposeOp: XLA_Op<"transpose", ...@@ -1052,7 +1053,7 @@ def XLA_TransposeOp: XLA_Op<"transpose",
XLA_Tensor:$operand, XLA_Tensor:$operand,
ElementsAttr:$permutation ElementsAttr:$permutation
); );
let results = (outs XLA_Tensor:$res); let results = (outs XLA_Tensor);
// TODO(b/129012527) These should be expressed as type constraints. // TODO(b/129012527) These should be expressed as type constraints.
let verifier = [{ let verifier = [{
...@@ -1078,7 +1079,7 @@ def XLA_TransposeOp: XLA_Op<"transpose", ...@@ -1078,7 +1079,7 @@ def XLA_TransposeOp: XLA_Op<"transpose",
permutationSize, operandRank)); permutationSize, operandRank));
} }
auto resultType = res()->getType().cast<RankedTensorType>(); auto resultType = getResult()->getType().cast<RankedTensorType>();
auto resultRank = resultType.getRank(); auto resultRank = resultType.getRank();
if (resultRank != operandRank) { if (resultRank != operandRank) {
return emitOpError( return emitOpError(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册