Commit 5b604850 authored by Alexander Belyaev, committed by TensorFlower Gardener

[MLIR] Fix incorrect usage of kDynamicSize constant.

Prepare to change kDynamicSize to int64_t::min.

PiperOrigin-RevId: 481104488
Parent 4f2c1bcc
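A minimal sketch of the pattern this commit applies throughout (the helper name ToI32Shape is illustrative, not taken from the diff): MLIR dimension sizes are int64_t and may be dynamic, so they must be tested with ShapedType::isDynamic() before being narrowed to the i32 sentinel -1, rather than compared against kDynamicSize or cast blindly.

#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinTypes.h"

// Narrow an MLIR shape to i32, mapping dynamic dimensions to the TFLite
// sentinel -1 instead of truncating whatever value kDynamicSize holds.
llvm::SmallVector<int32_t, 4> ToI32Shape(llvm::ArrayRef<int64_t> shape) {
  llvm::SmallVector<int32_t, 4> result;
  result.reserve(shape.size());
  for (int64_t dim : shape)
    result.push_back(mlir::ShapedType::isDynamic(dim)
                         ? -1
                         : static_cast<int32_t>(dim));
  return result;
}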
@@ -79,7 +79,8 @@ TFL::ReshapeOp InsertReshapeOp(Location loc, Value input, Type element_type,
   // TODO(renjieliu): Revisit this later.
   SmallVector<int32_t, 4> new_shape_array_i32;
   for (auto size : new_shape_array) {
-    new_shape_array_i32.push_back(static_cast<int32_t>(size));
+    new_shape_array_i32.push_back(
+        ShapedType::isDynamic(size) ? -1 : static_cast<int32_t>(size));
   }
   auto new_shape_attr =
       mlir::DenseIntElementsAttr::get(reshape_shape_type, new_shape_array_i32);
@@ -103,7 +104,7 @@ LogicalResult EnsureBias(Operation* op, int bias_idx,
   if (!output_type) return failure();
   // bias should be a vector sized of the last output dim.
-  int num_units = output_type.getDimSize(output_type.getRank() - 1);
+  int64_t num_units = output_type.getDimSize(output_type.getRank() - 1);
   auto bias_type =
       mlir::RankedTensorType::get({num_units}, output_type.getElementType());
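The int to int64_t widenings in this and the following hunks matter for the planned sentinel change: getDimSize() returns int64_t, and once kDynamicSize becomes int64_t's minimum, narrowing it to int silently destroys the sentinel. A standalone illustration, not part of the diff:

#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  // Stand-in for the planned kDynamicSize value (int64_t's minimum).
  const int64_t kFutureDynamicSize = std::numeric_limits<int64_t>::min();
  // On common two's-complement targets the cast keeps only the low 32
  // bits (all zero here), so a dynamic dimension would be misread as a
  // concrete size of 0.
  int narrowed = static_cast<int>(kFutureDynamicSize);
  std::printf("narrowed = %d, still the sentinel: %d\n", narrowed,
              narrowed == kFutureDynamicSize);  // prints: narrowed = 0, ... 0
  return 0;
}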
@@ -167,7 +168,7 @@ SmallVector<Value, 4> SliceOutputs(Operation* split_op, Value input,
     if (d == split_dim) {
       // Split dimension.
       slice_begin.push_back(begin);
-      int size = current_output_type.getDimSize(d);
+      int64_t size = current_output_type.getDimSize(d);
       slice_size.push_back(size);
       begin += size;
     } else {
@@ -214,16 +215,16 @@ LogicalResult LowerPackIntoConcatReshape::matchAndRewrite(
   SmallVector<int64_t, 4> concat_out_shape;
   SmallVector<int64_t, 4> pack_out_shape;
-  const int rank = input_type.getRank();
-  int pack_axis = pack_op.axis();
-  int count = pack_inputs.size();
+  const int64_t rank = input_type.getRank();
+  int64_t pack_axis = pack_op.axis();
+  size_t count = pack_inputs.size();
   if (pack_axis < 0) {
     pack_axis += rank;
   }
   // Concat out shape.
   for (int i = 0; i < rank; ++i) {
-    int dim_size = input_type.getDimSize(i);
+    int64_t dim_size = input_type.getDimSize(i);
     if (i == pack_axis) {
       dim_size *= count;
     }
@@ -301,8 +302,8 @@ LogicalResult UnrollSplit::matchAndRewrite(TFL::SplitOp split_op,
   // TODO(renjieliu): change to use split_dim when we raise the constants
   // as well.
-  int split_dim = -1;
-  for (int d = 0; d < input_type.getRank(); ++d) {
+  int64_t split_dim = -1;
+  for (int64_t d = 0; d < input_type.getRank(); ++d) {
     if (input_type.getDimSize(d) != output_type.getDimSize(d)) split_dim = d;
   }
@@ -336,7 +337,7 @@ LogicalResult UnrollSplitV::matchAndRewrite(TFL::SplitVOp splitv_op,
     if (result_type == nullptr) return failure();
   }
-  const int rank = input_type.getRank();
+  const int64_t rank = input_type.getRank();
   IntegerAttr dim_int = ExtractSingleElementAsInteger(split_dim_attr);
@@ -486,9 +487,9 @@ LogicalResult FullyConnectedToConv::matchAndRewrite(
   // Insert a reshape after the input.
   // Since the input maybe more than 2-d, we may collect the flat size of the
   // input then reshape into [1, 1, flat_size / depth, depth].
-  const int depth = input_type.getDimSize(input_type.getRank() - 1);
-  const int flat_size = input_type.getNumElements();
-  const int width = flat_size / depth;
+  const int64_t depth = input_type.getDimSize(input_type.getRank() - 1);
+  const int64_t flat_size = input_type.getNumElements();
+  const int64_t width = flat_size / depth;
   SmallVector<int64_t, 4> input_new_shape({1, 1, width, depth});
   auto reshaped_input =
       InsertReshapeOp(fc_op.getLoc(), input, input_type.getElementType(),
@@ -496,7 +497,7 @@ LogicalResult FullyConnectedToConv::matchAndRewrite(
   // Insert a reshape after the weight.
   // We will reshape the weight into [output, 1, 1, depth]
-  const int output_size = weight_type.getDimSize(0);
+  const int64_t output_size = weight_type.getDimSize(0);
   SmallVector<int64_t, 2> weight_new_shape({output_size, 1, 1, depth});
   auto reshaped_weight =
       InsertReshapeOp(fc_op.getLoc(), weight, weight_type.getElementType(),
@@ -629,10 +630,10 @@ LogicalResult ReduceMeanToAvgPool::matchAndRewrite(
     return failure();
   }
-  int batch = input_type.getDimSize(0);
-  int height = input_type.getDimSize(1);
-  int width = input_type.getDimSize(2);
-  int channel = input_type.getDimSize(3);
+  int64_t batch = input_type.getDimSize(0);
+  int64_t height = input_type.getDimSize(1);
+  int64_t width = input_type.getDimSize(2);
+  int64_t channel = input_type.getDimSize(3);
   auto avg_pool_output_type = RankedTensorType::get(
       {batch, 1, 1, channel}, input_type.getElementType());
@@ -298,11 +298,12 @@ TypeAttr RescaleQtype(Type input, Attribute factor) {
 // Precondition: output_val's is ranked tensor.
 DenseElementsAttr GetShape(Value output_val) {
   auto output_type = output_val.getType().cast<RankedTensorType>();
-  auto shape_vector = output_type.getShape();
-  std::vector<int32_t> shape;
-  shape.reserve(shape_vector.size());
-  for (auto shape_object : shape_vector) {
-    shape.push_back(shape_object);
+  SmallVector<int32_t> shape;
+  shape.reserve(output_type.getRank());
+  for (int64_t dim : output_type.getShape()) {
+    shape.push_back(ShapedType::isDynamic(dim) ? -1
+                                               : static_cast<int32_t>(dim));
   }
   return mlir::DenseElementsAttr::get(
       RankedTensorType::get(
@@ -1325,7 +1326,8 @@ struct ConvertTrivialTransposeOpToReshapeOp
     SmallVector<int32_t, 8> output_shape_values;
     for (auto dim : output_type.getShape()) {
-      output_shape_values.push_back(dim);
+      output_shape_values.push_back(
+          ShapedType::isDynamic(dim) ? -1 : static_cast<int32_t>(dim));
     }
     auto type = mlir::RankedTensorType::get(output_shape_values.size(),
                                             rewriter.getIntegerType(32));
@@ -1495,7 +1497,8 @@ struct FuseUnpackAndConcatToReshape
     // This is to workaround the unnecessary cast i64 -> i32.
     SmallVector<int32_t, 4> new_shape_array_i32;
     for (auto size : new_shape_array) {
-      new_shape_array_i32.push_back(static_cast<int32_t>(size));
+      new_shape_array_i32.push_back(
+          ShapedType::isDynamic(size) ? -1 : static_cast<int32_t>(size));
     }
     auto new_shape = rewriter.create<TFL::ConstOp>(
         concat_op.getLoc(),
@@ -71,7 +71,7 @@ template <typename ShapeContainerT>
 void SetTensorShapeProto(ShapeContainerT shape, TensorShapeProto* proto) {
   if (shape.hasRank()) {
     for (int64_t dim : shape.getShape()) {
-      proto->add_dim()->set_size(dim);
+      proto->add_dim()->set_size(mlir::ShapedType::isDynamic(dim) ? -1 : dim);
     }
   } else {
     proto->set_unknown_rank(true);
@@ -3510,7 +3510,7 @@ class ConvertSplitVOp : public OpRewritePattern<TF::SplitVOp> {
     for (auto dim : llvm::enumerate(split_sizes_attr)) {
       int64_t dim_val = dim.value().getSExtValue();
       split_sizes.push_back(dim_val);
-      if (dim_val == ShapedType::kDynamicSize) {
+      if (dim_val == -1) {
        // We cannot have more than one dynamic dimension.
        assert(!dynamic_dim_index && "invalid split sizes");
        dynamic_dim_index = dim.index();
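The last hunk goes the other way: the -1 in a SplitV split_sizes attribute is TensorFlow's own "infer this size" marker, attribute data that only coincidentally matched the old kDynamicSize value, so the literal is the correct comparison. A minimal sketch of the distinction (the helper name is illustrative):

#include <cstdint>

// TF's SplitV encodes at most one inferable split size as the literal -1
// inside the split_sizes attribute. That sentinel is part of the op's
// data, not an MLIR shape, so it must not be spelled kDynamicSize.
bool IsInferredSplitSize(int64_t dim_val) { return dim_val == -1; }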