提交 8b75ad10 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Add InsertSlicesOp to the VectorOps dialect.

PiperOrigin-RevId: 285830394
Change-Id: Iea5547d55efedd5150c82edf7afb5e4ed2b373d3
上级 071ae24b
......@@ -352,7 +352,7 @@ def Vector_ExtractSlicesOp :
%3 = vector.extract_slices %2, [2, 2], [1, 1]
: vector<4x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>,
vector<2x2xf32>, vector<2x2xf32>>
vector<2x2xf32>, vector<2x1xf32>>
```
}];
let builders = [OpBuilder<
......@@ -439,6 +439,58 @@ def Vector_InsertOp :
}];
}
def Vector_InsertSlicesOp :
Vector_Op<"insert_slices", [NoSideEffect]>,
Arguments<(ins TupleOf<[AnyVector]>:$vectors, I64ArrayAttr:$sizes,
I64ArrayAttr:$strides)>,
Results<(outs AnyVector)> {
let summary = "vector insert slices operation";
let description = [{
Takes a tuple of vector slices and inserts them into the vector result
according to the 'sizes' and 'strides' parameters.
The arguments 'sizes' and 'strides' represent a specification for
generating the unrolling of 'vector' shape, which has all slices of shape
'sizes' except for slices at dimension boundaries when 'vector' dimension
sizes are not a multiple of 'sizes'.
Each slice in 'vectors' is at the tuple element index corresponding to the
linear index of the slice w.r.t the unrolling scheme represented by 'sizes'.
Currently, only unit strides are supported.
Examples:
```
%0 = vector.extract_slices %0, [2, 2], [1, 1]
: vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
%1 = vector.insert_slices %0, [2, 2], [1, 1]
: tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
// Example with partial slices at dimension boundaries.
%3 = vector.extract_slices %2, [2, 2], [1, 1]
: vector<4x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>,
vector<2x2xf32>, vector<2x1xf32>>
%4 = vector.insert_slices %3, [2, 2], [1, 1]
: tuple<vector<2x2xf32>, vector<2x1xf32>,
vector<2x2xf32>, vector<2x1xf32>> into vector<4x3xf32>
```
}];
let extraClassDeclaration = [{
TupleType getSourceTupleType() {
return vectors()->getType().cast<TupleType>();
}
VectorType getResultVectorType() {
return getResult()->getType().cast<VectorType>();
}
void getSizes(SmallVectorImpl<int64_t> &results);
void getStrides(SmallVectorImpl<int64_t> &results);
static StringRef getSizesAttrName() { return "sizes"; }
static StringRef getStridesAttrName() { return "strides"; }
}];
}
def Vector_InsertStridedSliceOp :
Vector_Op<"insert_strided_slice", [NoSideEffect,
PredOpTrait<"operand #0 and result have same element type",
......
......@@ -825,6 +825,60 @@ static LogicalResult verify(InsertOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// InsertSlicesOp
//===----------------------------------------------------------------------===//
static ParseResult parseInsertSlicesOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType operandInfo;
ArrayAttr sizesAttr;
StringRef sizesAttrName = InsertSlicesOp::getSizesAttrName();
ArrayAttr stridesAttr;
StringRef stridesAttrName = InsertSlicesOp::getStridesAttrName();
TupleType tupleType;
VectorType resultVectorType;
return failure(
parser.parseOperand(operandInfo) || parser.parseComma() ||
parser.parseAttribute(sizesAttr, sizesAttrName, result.attributes) ||
parser.parseComma() ||
parser.parseAttribute(stridesAttr, stridesAttrName, result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(tupleType) ||
parser.parseKeywordType("into", resultVectorType) ||
parser.resolveOperand(operandInfo, tupleType, result.operands) ||
parser.addTypeToList(resultVectorType, result.types));
}
static void print(OpAsmPrinter &p, InsertSlicesOp op) {
p << op.getOperationName() << ' ' << *op.vectors() << ", ";
p << op.sizes() << ", " << op.strides();
p.printOptionalAttrDict(
op.getAttrs(),
/*elidedAttrs=*/{InsertSlicesOp::getSizesAttrName(),
InsertSlicesOp::getStridesAttrName()});
p << " : " << op.vectors()->getType();
p << " into " << op.getResultVectorType();
}
static LogicalResult verify(InsertSlicesOp op) {
SmallVector<int64_t, 4> sizes;
op.getSizes(sizes);
SmallVector<int64_t, 4> strides;
op.getStrides(strides);
return isValidExtractOrInsertSlicesType(
op.getOperation(), op.getResultVectorType(), op.getSourceTupleType(),
sizes, strides);
}
void InsertSlicesOp::getSizes(SmallVectorImpl<int64_t> &results) {
populateFromInt64AttrArray(sizes(), results);
}
void InsertSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
populateFromInt64AttrArray(strides(), results);
}
//===----------------------------------------------------------------------===//
// InsertStridedSliceOp
//===----------------------------------------------------------------------===//
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册