提交 a80e13e1 编写于 作者: P psandoz

8025534: Unsafe typecast in java.util.stream.Streams.Nodes

8025538: Unsafe typecast in java.util.stream.SpinedBuffer
8025533: Unsafe typecast in java.util.stream.Streams.RangeIntSpliterator.splitPoint()
8025525: Unsafe typecast in java.util.stream.Node.OfPrimitive.asArray()
Reviewed-by: chegar
上级 10d7bbf7
...@@ -149,7 +149,9 @@ interface Node<T> { ...@@ -149,7 +149,9 @@ interface Node<T> {
/** /**
* Copies the content of this {@code Node} into an array, starting at a * Copies the content of this {@code Node} into an array, starting at a
* given offset into the array. It is the caller's responsibility to ensure * given offset into the array. It is the caller's responsibility to ensure
* there is sufficient room in the array. * there is sufficient room in the array, otherwise unspecified behaviour
* will occur if the array length is less than the number of elements
* contained in this node.
* *
* @param array the array into which to copy the contents of this * @param array the array into which to copy the contents of this
* {@code Node} * {@code Node}
...@@ -258,6 +260,12 @@ interface Node<T> { ...@@ -258,6 +260,12 @@ interface Node<T> {
*/ */
@Override @Override
default T[] asArray(IntFunction<T[]> generator) { default T[] asArray(IntFunction<T[]> generator) {
if (java.util.stream.Tripwire.ENABLED)
java.util.stream.Tripwire.trip(getClass(), "{0} calling Node.OfPrimitive.asArray");
long size = count();
if (size >= Nodes.MAX_ARRAY_SIZE)
throw new IllegalArgumentException(Nodes.BAD_SIZE);
T[] boxed = generator.apply((int) count()); T[] boxed = generator.apply((int) count());
copyInto(boxed, 0); copyInto(boxed, 0);
return boxed; return boxed;
......
...@@ -60,6 +60,9 @@ final class Nodes { ...@@ -60,6 +60,9 @@ final class Nodes {
*/ */
static final long MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8; static final long MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;
// IllegalArgumentException messages
static final String BAD_SIZE = "Stream size exceeds max array size";
@SuppressWarnings("raw") @SuppressWarnings("raw")
private static final Node EMPTY_NODE = new EmptyNode.OfRef(); private static final Node EMPTY_NODE = new EmptyNode.OfRef();
private static final Node.OfInt EMPTY_INT_NODE = new EmptyNode.OfInt(); private static final Node.OfInt EMPTY_INT_NODE = new EmptyNode.OfInt();
...@@ -317,7 +320,7 @@ final class Nodes { ...@@ -317,7 +320,7 @@ final class Nodes {
long size = helper.exactOutputSizeIfKnown(spliterator); long size = helper.exactOutputSizeIfKnown(spliterator);
if (size >= 0 && spliterator.hasCharacteristics(Spliterator.SUBSIZED)) { if (size >= 0 && spliterator.hasCharacteristics(Spliterator.SUBSIZED)) {
if (size >= MAX_ARRAY_SIZE) if (size >= MAX_ARRAY_SIZE)
throw new IllegalArgumentException("Stream size exceeds max array size"); throw new IllegalArgumentException(BAD_SIZE);
P_OUT[] array = generator.apply((int) size); P_OUT[] array = generator.apply((int) size);
new SizedCollectorTask.OfRef<>(spliterator, helper, array).invoke(); new SizedCollectorTask.OfRef<>(spliterator, helper, array).invoke();
return node(array); return node(array);
...@@ -354,7 +357,7 @@ final class Nodes { ...@@ -354,7 +357,7 @@ final class Nodes {
long size = helper.exactOutputSizeIfKnown(spliterator); long size = helper.exactOutputSizeIfKnown(spliterator);
if (size >= 0 && spliterator.hasCharacteristics(Spliterator.SUBSIZED)) { if (size >= 0 && spliterator.hasCharacteristics(Spliterator.SUBSIZED)) {
if (size >= MAX_ARRAY_SIZE) if (size >= MAX_ARRAY_SIZE)
throw new IllegalArgumentException("Stream size exceeds max array size"); throw new IllegalArgumentException(BAD_SIZE);
int[] array = new int[(int) size]; int[] array = new int[(int) size];
new SizedCollectorTask.OfInt<>(spliterator, helper, array).invoke(); new SizedCollectorTask.OfInt<>(spliterator, helper, array).invoke();
return node(array); return node(array);
...@@ -392,7 +395,7 @@ final class Nodes { ...@@ -392,7 +395,7 @@ final class Nodes {
long size = helper.exactOutputSizeIfKnown(spliterator); long size = helper.exactOutputSizeIfKnown(spliterator);
if (size >= 0 && spliterator.hasCharacteristics(Spliterator.SUBSIZED)) { if (size >= 0 && spliterator.hasCharacteristics(Spliterator.SUBSIZED)) {
if (size >= MAX_ARRAY_SIZE) if (size >= MAX_ARRAY_SIZE)
throw new IllegalArgumentException("Stream size exceeds max array size"); throw new IllegalArgumentException(BAD_SIZE);
long[] array = new long[(int) size]; long[] array = new long[(int) size];
new SizedCollectorTask.OfLong<>(spliterator, helper, array).invoke(); new SizedCollectorTask.OfLong<>(spliterator, helper, array).invoke();
return node(array); return node(array);
...@@ -430,7 +433,7 @@ final class Nodes { ...@@ -430,7 +433,7 @@ final class Nodes {
long size = helper.exactOutputSizeIfKnown(spliterator); long size = helper.exactOutputSizeIfKnown(spliterator);
if (size >= 0 && spliterator.hasCharacteristics(Spliterator.SUBSIZED)) { if (size >= 0 && spliterator.hasCharacteristics(Spliterator.SUBSIZED)) {
if (size >= MAX_ARRAY_SIZE) if (size >= MAX_ARRAY_SIZE)
throw new IllegalArgumentException("Stream size exceeds max array size"); throw new IllegalArgumentException(BAD_SIZE);
double[] array = new double[(int) size]; double[] array = new double[(int) size];
new SizedCollectorTask.OfDouble<>(spliterator, helper, array).invoke(); new SizedCollectorTask.OfDouble<>(spliterator, helper, array).invoke();
return node(array); return node(array);
...@@ -460,7 +463,10 @@ final class Nodes { ...@@ -460,7 +463,10 @@ final class Nodes {
*/ */
public static <T> Node<T> flatten(Node<T> node, IntFunction<T[]> generator) { public static <T> Node<T> flatten(Node<T> node, IntFunction<T[]> generator) {
if (node.getChildCount() > 0) { if (node.getChildCount() > 0) {
T[] array = generator.apply((int) node.count()); long size = node.count();
if (size >= MAX_ARRAY_SIZE)
throw new IllegalArgumentException(BAD_SIZE);
T[] array = generator.apply((int) size);
new ToArrayTask.OfRef<>(node, array, 0).invoke(); new ToArrayTask.OfRef<>(node, array, 0).invoke();
return node(array); return node(array);
} else { } else {
...@@ -483,7 +489,10 @@ final class Nodes { ...@@ -483,7 +489,10 @@ final class Nodes {
*/ */
public static Node.OfInt flattenInt(Node.OfInt node) { public static Node.OfInt flattenInt(Node.OfInt node) {
if (node.getChildCount() > 0) { if (node.getChildCount() > 0) {
int[] array = new int[(int) node.count()]; long size = node.count();
if (size >= MAX_ARRAY_SIZE)
throw new IllegalArgumentException(BAD_SIZE);
int[] array = new int[(int) size];
new ToArrayTask.OfInt(node, array, 0).invoke(); new ToArrayTask.OfInt(node, array, 0).invoke();
return node(array); return node(array);
} else { } else {
...@@ -506,7 +515,10 @@ final class Nodes { ...@@ -506,7 +515,10 @@ final class Nodes {
*/ */
public static Node.OfLong flattenLong(Node.OfLong node) { public static Node.OfLong flattenLong(Node.OfLong node) {
if (node.getChildCount() > 0) { if (node.getChildCount() > 0) {
long[] array = new long[(int) node.count()]; long size = node.count();
if (size >= MAX_ARRAY_SIZE)
throw new IllegalArgumentException(BAD_SIZE);
long[] array = new long[(int) size];
new ToArrayTask.OfLong(node, array, 0).invoke(); new ToArrayTask.OfLong(node, array, 0).invoke();
return node(array); return node(array);
} else { } else {
...@@ -529,7 +541,10 @@ final class Nodes { ...@@ -529,7 +541,10 @@ final class Nodes {
*/ */
public static Node.OfDouble flattenDouble(Node.OfDouble node) { public static Node.OfDouble flattenDouble(Node.OfDouble node) {
if (node.getChildCount() > 0) { if (node.getChildCount() > 0) {
double[] array = new double[(int) node.count()]; long size = node.count();
if (size >= MAX_ARRAY_SIZE)
throw new IllegalArgumentException(BAD_SIZE);
double[] array = new double[(int) size];
new ToArrayTask.OfDouble(node, array, 0).invoke(); new ToArrayTask.OfDouble(node, array, 0).invoke();
return node(array); return node(array);
} else { } else {
...@@ -627,7 +642,7 @@ final class Nodes { ...@@ -627,7 +642,7 @@ final class Nodes {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
ArrayNode(long size, IntFunction<T[]> generator) { ArrayNode(long size, IntFunction<T[]> generator) {
if (size >= MAX_ARRAY_SIZE) if (size >= MAX_ARRAY_SIZE)
throw new IllegalArgumentException("Stream size exceeds max array size"); throw new IllegalArgumentException(BAD_SIZE);
this.array = generator.apply((int) size); this.array = generator.apply((int) size);
this.curSize = 0; this.curSize = 0;
} }
...@@ -777,12 +792,17 @@ final class Nodes { ...@@ -777,12 +792,17 @@ final class Nodes {
public void copyInto(T[] array, int offset) { public void copyInto(T[] array, int offset) {
Objects.requireNonNull(array); Objects.requireNonNull(array);
left.copyInto(array, offset); left.copyInto(array, offset);
// Cast to int is safe since it is the callers responsibility to
// ensure that there is sufficient room in the array
right.copyInto(array, offset + (int) left.count()); right.copyInto(array, offset + (int) left.count());
} }
@Override @Override
public T[] asArray(IntFunction<T[]> generator) { public T[] asArray(IntFunction<T[]> generator) {
T[] array = generator.apply((int) count()); long size = count();
if (size >= MAX_ARRAY_SIZE)
throw new IllegalArgumentException(BAD_SIZE);
T[] array = generator.apply((int) size);
copyInto(array, 0); copyInto(array, 0);
return array; return array;
} }
...@@ -836,12 +856,17 @@ final class Nodes { ...@@ -836,12 +856,17 @@ final class Nodes {
@Override @Override
public void copyInto(T_ARR array, int offset) { public void copyInto(T_ARR array, int offset) {
left.copyInto(array, offset); left.copyInto(array, offset);
// Cast to int is safe since it is the callers responsibility to
// ensure that there is sufficient room in the array
right.copyInto(array, offset + (int) left.count()); right.copyInto(array, offset + (int) left.count());
} }
@Override @Override
public T_ARR asPrimitiveArray() { public T_ARR asPrimitiveArray() {
T_ARR array = newArray((int) count()); long size = count();
if (size >= MAX_ARRAY_SIZE)
throw new IllegalArgumentException(BAD_SIZE);
T_ARR array = newArray((int) size);
copyInto(array, 0); copyInto(array, 0);
return array; return array;
} }
...@@ -1287,7 +1312,7 @@ final class Nodes { ...@@ -1287,7 +1312,7 @@ final class Nodes {
IntArrayNode(long size) { IntArrayNode(long size) {
if (size >= MAX_ARRAY_SIZE) if (size >= MAX_ARRAY_SIZE)
throw new IllegalArgumentException("Stream size exceeds max array size"); throw new IllegalArgumentException(BAD_SIZE);
this.array = new int[(int) size]; this.array = new int[(int) size];
this.curSize = 0; this.curSize = 0;
} }
...@@ -1343,7 +1368,7 @@ final class Nodes { ...@@ -1343,7 +1368,7 @@ final class Nodes {
LongArrayNode(long size) { LongArrayNode(long size) {
if (size >= MAX_ARRAY_SIZE) if (size >= MAX_ARRAY_SIZE)
throw new IllegalArgumentException("Stream size exceeds max array size"); throw new IllegalArgumentException(BAD_SIZE);
this.array = new long[(int) size]; this.array = new long[(int) size];
this.curSize = 0; this.curSize = 0;
} }
...@@ -1397,7 +1422,7 @@ final class Nodes { ...@@ -1397,7 +1422,7 @@ final class Nodes {
DoubleArrayNode(long size) { DoubleArrayNode(long size) {
if (size >= MAX_ARRAY_SIZE) if (size >= MAX_ARRAY_SIZE)
throw new IllegalArgumentException("Stream size exceeds max array size"); throw new IllegalArgumentException(BAD_SIZE);
this.array = new double[(int) size]; this.array = new double[(int) size];
this.curSize = 0; this.curSize = 0;
} }
...@@ -1843,8 +1868,8 @@ final class Nodes { ...@@ -1843,8 +1868,8 @@ final class Nodes {
task = task.makeChild(rightSplit, task.offset + leftSplitSize, task = task.makeChild(rightSplit, task.offset + leftSplitSize,
task.length - leftSplitSize); task.length - leftSplitSize);
} }
if (task.offset + task.length >= MAX_ARRAY_SIZE)
throw new IllegalArgumentException("Stream size exceeds max array size"); assert task.offset + task.length < MAX_ARRAY_SIZE;
T_SINK sink = (T_SINK) task; T_SINK sink = (T_SINK) task;
task.helper.wrapAndCopyInto(sink, rightSplit); task.helper.wrapAndCopyInto(sink, rightSplit);
task.propagateCompletion(); task.propagateCompletion();
...@@ -1854,10 +1879,13 @@ final class Nodes { ...@@ -1854,10 +1879,13 @@ final class Nodes {
@Override @Override
public void begin(long size) { public void begin(long size) {
if(size > length) if (size > length)
throw new IllegalStateException("size passed to Sink.begin exceeds array length"); throw new IllegalStateException("size passed to Sink.begin exceeds array length");
// Casts to int are safe since absolute size is verified to be within
// bounds when the root concrete SizedCollectorTask is constructed
// with the shared array
index = (int) offset; index = (int) offset;
fence = (int) offset + (int) length; fence = index + (int) length;
} }
@SuppressWarnings("serial") @SuppressWarnings("serial")
......
...@@ -277,8 +277,6 @@ final class SortedOps { ...@@ -277,8 +277,6 @@ final class SortedOps {
} }
} }
private static final String BAD_SIZE = "Stream size exceeds max array size";
/** /**
* {@link Sink} for implementing sort on SIZED reference streams. * {@link Sink} for implementing sort on SIZED reference streams.
*/ */
...@@ -295,7 +293,7 @@ final class SortedOps { ...@@ -295,7 +293,7 @@ final class SortedOps {
@Override @Override
public void begin(long size) { public void begin(long size) {
if (size >= Nodes.MAX_ARRAY_SIZE) if (size >= Nodes.MAX_ARRAY_SIZE)
throw new IllegalArgumentException(BAD_SIZE); throw new IllegalArgumentException(Nodes.BAD_SIZE);
array = (T[]) new Object[(int) size]; array = (T[]) new Object[(int) size];
} }
...@@ -330,7 +328,7 @@ final class SortedOps { ...@@ -330,7 +328,7 @@ final class SortedOps {
@Override @Override
public void begin(long size) { public void begin(long size) {
if (size >= Nodes.MAX_ARRAY_SIZE) if (size >= Nodes.MAX_ARRAY_SIZE)
throw new IllegalArgumentException(BAD_SIZE); throw new IllegalArgumentException(Nodes.BAD_SIZE);
list = (size >= 0) ? new ArrayList<T>((int) size) : new ArrayList<T>(); list = (size >= 0) ? new ArrayList<T>((int) size) : new ArrayList<T>();
} }
...@@ -363,7 +361,7 @@ final class SortedOps { ...@@ -363,7 +361,7 @@ final class SortedOps {
@Override @Override
public void begin(long size) { public void begin(long size) {
if (size >= Nodes.MAX_ARRAY_SIZE) if (size >= Nodes.MAX_ARRAY_SIZE)
throw new IllegalArgumentException(BAD_SIZE); throw new IllegalArgumentException(Nodes.BAD_SIZE);
array = new int[(int) size]; array = new int[(int) size];
} }
...@@ -396,7 +394,7 @@ final class SortedOps { ...@@ -396,7 +394,7 @@ final class SortedOps {
@Override @Override
public void begin(long size) { public void begin(long size) {
if (size >= Nodes.MAX_ARRAY_SIZE) if (size >= Nodes.MAX_ARRAY_SIZE)
throw new IllegalArgumentException(BAD_SIZE); throw new IllegalArgumentException(Nodes.BAD_SIZE);
b = (size > 0) ? new SpinedBuffer.OfInt((int) size) : new SpinedBuffer.OfInt(); b = (size > 0) ? new SpinedBuffer.OfInt((int) size) : new SpinedBuffer.OfInt();
} }
...@@ -430,7 +428,7 @@ final class SortedOps { ...@@ -430,7 +428,7 @@ final class SortedOps {
@Override @Override
public void begin(long size) { public void begin(long size) {
if (size >= Nodes.MAX_ARRAY_SIZE) if (size >= Nodes.MAX_ARRAY_SIZE)
throw new IllegalArgumentException(BAD_SIZE); throw new IllegalArgumentException(Nodes.BAD_SIZE);
array = new long[(int) size]; array = new long[(int) size];
} }
...@@ -463,7 +461,7 @@ final class SortedOps { ...@@ -463,7 +461,7 @@ final class SortedOps {
@Override @Override
public void begin(long size) { public void begin(long size) {
if (size >= Nodes.MAX_ARRAY_SIZE) if (size >= Nodes.MAX_ARRAY_SIZE)
throw new IllegalArgumentException(BAD_SIZE); throw new IllegalArgumentException(Nodes.BAD_SIZE);
b = (size > 0) ? new SpinedBuffer.OfLong((int) size) : new SpinedBuffer.OfLong(); b = (size > 0) ? new SpinedBuffer.OfLong((int) size) : new SpinedBuffer.OfLong();
} }
...@@ -497,7 +495,7 @@ final class SortedOps { ...@@ -497,7 +495,7 @@ final class SortedOps {
@Override @Override
public void begin(long size) { public void begin(long size) {
if (size >= Nodes.MAX_ARRAY_SIZE) if (size >= Nodes.MAX_ARRAY_SIZE)
throw new IllegalArgumentException(BAD_SIZE); throw new IllegalArgumentException(Nodes.BAD_SIZE);
array = new double[(int) size]; array = new double[(int) size];
} }
...@@ -530,7 +528,7 @@ final class SortedOps { ...@@ -530,7 +528,7 @@ final class SortedOps {
@Override @Override
public void begin(long size) { public void begin(long size) {
if (size >= Nodes.MAX_ARRAY_SIZE) if (size >= Nodes.MAX_ARRAY_SIZE)
throw new IllegalArgumentException(BAD_SIZE); throw new IllegalArgumentException(Nodes.BAD_SIZE);
b = (size > 0) ? new SpinedBuffer.OfDouble((int) size) : new SpinedBuffer.OfDouble(); b = (size > 0) ? new SpinedBuffer.OfDouble((int) size) : new SpinedBuffer.OfDouble();
} }
......
...@@ -156,6 +156,9 @@ class SpinedBuffer<E> ...@@ -156,6 +156,9 @@ class SpinedBuffer<E>
public E get(long index) { public E get(long index) {
// @@@ can further optimize by caching last seen spineIndex, // @@@ can further optimize by caching last seen spineIndex,
// which is going to be right most of the time // which is going to be right most of the time
// Casts to int are safe since the spine array index is the index minus
// the prior element count from the current spine
if (spineIndex == 0) { if (spineIndex == 0) {
if (index < elementIndex) if (index < elementIndex)
return curChunk[((int) index)]; return curChunk[((int) index)];
...@@ -201,11 +204,11 @@ class SpinedBuffer<E> ...@@ -201,11 +204,11 @@ class SpinedBuffer<E>
* elements into it. * elements into it.
*/ */
public E[] asArray(IntFunction<E[]> arrayFactory) { public E[] asArray(IntFunction<E[]> arrayFactory) {
// @@@ will fail for size == MAX_VALUE long size = count();
E[] result = arrayFactory.apply((int) count()); if (size >= Nodes.MAX_ARRAY_SIZE)
throw new IllegalArgumentException(Nodes.BAD_SIZE);
E[] result = arrayFactory.apply((int) size);
copyInto(result, 0); copyInto(result, 0);
return result; return result;
} }
...@@ -547,8 +550,10 @@ class SpinedBuffer<E> ...@@ -547,8 +550,10 @@ class SpinedBuffer<E>
} }
public T_ARR asPrimitiveArray() { public T_ARR asPrimitiveArray() {
// @@@ will fail for size == MAX_VALUE long size = count();
T_ARR result = newArray((int) count()); if (size >= Nodes.MAX_ARRAY_SIZE)
throw new IllegalArgumentException(Nodes.BAD_SIZE);
T_ARR result = newArray((int) size);
copyInto(result, 0); copyInto(result, 0);
return result; return result;
} }
...@@ -760,11 +765,13 @@ class SpinedBuffer<E> ...@@ -760,11 +765,13 @@ class SpinedBuffer<E>
} }
public int get(long index) { public int get(long index) {
// Casts to int are safe since the spine array index is the index minus
// the prior element count from the current spine
int ch = chunkFor(index); int ch = chunkFor(index);
if (spineIndex == 0 && ch == 0) if (spineIndex == 0 && ch == 0)
return curChunk[(int) index]; return curChunk[(int) index];
else else
return spine[ch][(int) (index-priorElementCount[ch])]; return spine[ch][(int) (index - priorElementCount[ch])];
} }
@Override @Override
...@@ -871,11 +878,13 @@ class SpinedBuffer<E> ...@@ -871,11 +878,13 @@ class SpinedBuffer<E>
} }
public long get(long index) { public long get(long index) {
// Casts to int are safe since the spine array index is the index minus
// the prior element count from the current spine
int ch = chunkFor(index); int ch = chunkFor(index);
if (spineIndex == 0 && ch == 0) if (spineIndex == 0 && ch == 0)
return curChunk[(int) index]; return curChunk[(int) index];
else else
return spine[ch][(int) (index-priorElementCount[ch])]; return spine[ch][(int) (index - priorElementCount[ch])];
} }
@Override @Override
...@@ -984,11 +993,13 @@ class SpinedBuffer<E> ...@@ -984,11 +993,13 @@ class SpinedBuffer<E>
} }
public double get(long index) { public double get(long index) {
// Casts to int are safe since the spine array index is the index minus
// the prior element count from the current spine
int ch = chunkFor(index); int ch = chunkFor(index);
if (spineIndex == 0 && ch == 0) if (spineIndex == 0 && ch == 0)
return curChunk[(int) index]; return curChunk[(int) index];
else else
return spine[ch][(int) (index-priorElementCount[ch])]; return spine[ch][(int) (index - priorElementCount[ch])];
} }
@Override @Override
......
...@@ -169,7 +169,9 @@ final class Streams { ...@@ -169,7 +169,9 @@ final class Streams {
private int splitPoint(long size) { private int splitPoint(long size) {
int d = (size < BALANCED_SPLIT_THRESHOLD) ? 2 : RIGHT_BALANCED_SPLIT_RATIO; int d = (size < BALANCED_SPLIT_THRESHOLD) ? 2 : RIGHT_BALANCED_SPLIT_RATIO;
// 2 <= size <= 2^32 // Cast to int is safe since:
// 2 <= size < 2^32
// 2 <= d <= 8
return (int) (size / d); return (int) (size / d);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册