diff --git a/src/share/classes/java/util/stream/SortedOps.java b/src/share/classes/java/util/stream/SortedOps.java index 8dcabb491452354c97c88e975cd196f746fabb7c..810de1ca20ed14c598d3543bdb1a43d0097a86b3 100644 --- a/src/share/classes/java/util/stream/SortedOps.java +++ b/src/share/classes/java/util/stream/SortedOps.java @@ -277,17 +277,61 @@ final class SortedOps { } } + /** + * Abstract {@link Sink} for implementing sort on reference streams. + * + *

+ * Note: documentation below applies to reference and all primitive sinks. + *

+ * Sorting sinks first accept all elements, buffering then into an array + * or a re-sizable data structure, if the size of the pipeline is known or + * unknown respectively. At the end of the sink protocol those elements are + * sorted and then pushed downstream. + * This class records if {@link #cancellationRequested} is called. If so it + * can be inferred that the source pushing source elements into the pipeline + * knows that the pipeline is short-circuiting. In such cases sub-classes + * pushing elements downstream will preserve the short-circuiting protocol + * by calling {@code downstream.cancellationRequested()} and checking the + * result is {@code false} before an element is pushed. + *

+ * Note that the above behaviour is an optimization for sorting with + * sequential streams. It is not an error that more elements, than strictly + * required to produce a result, may flow through the pipeline. This can + * occur, in general (not restricted to just sorting), for short-circuiting + * parallel pipelines. + */ + private static abstract class AbstractRefSortingSink extends Sink.ChainedReference { + protected final Comparator comparator; + // @@@ could be a lazy final value, if/when support is added + protected boolean cancellationWasRequested; + + AbstractRefSortingSink(Sink downstream, Comparator comparator) { + super(downstream); + this.comparator = comparator; + } + + /** + * Records is cancellation is requested so short-circuiting behaviour + * can be preserved when the sorted elements are pushed downstream. + * + * @return false, as this sink never short-circuits. + */ + @Override + public final boolean cancellationRequested() { + cancellationWasRequested = true; + return false; + } + } + /** * {@link Sink} for implementing sort on SIZED reference streams. */ - private static final class SizedRefSortingSink extends Sink.ChainedReference { - private final Comparator comparator; + private static final class SizedRefSortingSink extends AbstractRefSortingSink { private T[] array; private int offset; SizedRefSortingSink(Sink sink, Comparator comparator) { - super(sink); - this.comparator = comparator; + super(sink, comparator); } @Override @@ -301,8 +345,14 @@ final class SortedOps { public void end() { Arrays.sort(array, 0, offset, comparator); downstream.begin(offset); - for (int i = 0; i < offset; i++) - downstream.accept(array[i]); + if (!cancellationWasRequested) { + for (int i = 0; i < offset; i++) + downstream.accept(array[i]); + } + else { + for (int i = 0; i < offset && !downstream.cancellationRequested(); i++) + downstream.accept(array[i]); + } downstream.end(); array = null; } @@ -316,13 +366,11 @@ final class SortedOps { /** * {@link Sink} for implementing sort on reference streams. */ - private static final class RefSortingSink extends Sink.ChainedReference { - private final Comparator comparator; + private static final class RefSortingSink extends AbstractRefSortingSink { private ArrayList list; RefSortingSink(Sink sink, Comparator comparator) { - super(sink); - this.comparator = comparator; + super(sink, comparator); } @Override @@ -336,7 +384,15 @@ final class SortedOps { public void end() { list.sort(comparator); downstream.begin(list.size()); - list.forEach(downstream::accept); + if (!cancellationWasRequested) { + list.forEach(downstream::accept); + } + else { + for (T t : list) { + if (downstream.cancellationRequested()) break; + downstream.accept(t); + } + } downstream.end(); list = null; } @@ -347,10 +403,27 @@ final class SortedOps { } } + /** + * Abstract {@link Sink} for implementing sort on int streams. + */ + private static abstract class AbstractIntSortingSink extends Sink.ChainedInt { + protected boolean cancellationWasRequested; + + AbstractIntSortingSink(Sink downstream) { + super(downstream); + } + + @Override + public final boolean cancellationRequested() { + cancellationWasRequested = true; + return false; + } + } + /** * {@link Sink} for implementing sort on SIZED int streams. */ - private static final class SizedIntSortingSink extends Sink.ChainedInt { + private static final class SizedIntSortingSink extends AbstractIntSortingSink { private int[] array; private int offset; @@ -369,8 +442,14 @@ final class SortedOps { public void end() { Arrays.sort(array, 0, offset); downstream.begin(offset); - for (int i = 0; i < offset; i++) - downstream.accept(array[i]); + if (!cancellationWasRequested) { + for (int i = 0; i < offset; i++) + downstream.accept(array[i]); + } + else { + for (int i = 0; i < offset && !downstream.cancellationRequested(); i++) + downstream.accept(array[i]); + } downstream.end(); array = null; } @@ -384,7 +463,7 @@ final class SortedOps { /** * {@link Sink} for implementing sort on int streams. */ - private static final class IntSortingSink extends Sink.ChainedInt { + private static final class IntSortingSink extends AbstractIntSortingSink { private SpinedBuffer.OfInt b; IntSortingSink(Sink sink) { @@ -403,8 +482,16 @@ final class SortedOps { int[] ints = b.asPrimitiveArray(); Arrays.sort(ints); downstream.begin(ints.length); - for (int anInt : ints) - downstream.accept(anInt); + if (!cancellationWasRequested) { + for (int anInt : ints) + downstream.accept(anInt); + } + else { + for (int anInt : ints) { + if (downstream.cancellationRequested()) break; + downstream.accept(anInt); + } + } downstream.end(); } @@ -414,10 +501,27 @@ final class SortedOps { } } + /** + * Abstract {@link Sink} for implementing sort on long streams. + */ + private static abstract class AbstractLongSortingSink extends Sink.ChainedLong { + protected boolean cancellationWasRequested; + + AbstractLongSortingSink(Sink downstream) { + super(downstream); + } + + @Override + public final boolean cancellationRequested() { + cancellationWasRequested = true; + return false; + } + } + /** * {@link Sink} for implementing sort on SIZED long streams. */ - private static final class SizedLongSortingSink extends Sink.ChainedLong { + private static final class SizedLongSortingSink extends AbstractLongSortingSink { private long[] array; private int offset; @@ -436,8 +540,14 @@ final class SortedOps { public void end() { Arrays.sort(array, 0, offset); downstream.begin(offset); - for (int i = 0; i < offset; i++) - downstream.accept(array[i]); + if (!cancellationWasRequested) { + for (int i = 0; i < offset; i++) + downstream.accept(array[i]); + } + else { + for (int i = 0; i < offset && !downstream.cancellationRequested(); i++) + downstream.accept(array[i]); + } downstream.end(); array = null; } @@ -451,7 +561,7 @@ final class SortedOps { /** * {@link Sink} for implementing sort on long streams. */ - private static final class LongSortingSink extends Sink.ChainedLong { + private static final class LongSortingSink extends AbstractLongSortingSink { private SpinedBuffer.OfLong b; LongSortingSink(Sink sink) { @@ -470,8 +580,16 @@ final class SortedOps { long[] longs = b.asPrimitiveArray(); Arrays.sort(longs); downstream.begin(longs.length); - for (long aLong : longs) - downstream.accept(aLong); + if (!cancellationWasRequested) { + for (long aLong : longs) + downstream.accept(aLong); + } + else { + for (long aLong : longs) { + if (downstream.cancellationRequested()) break; + downstream.accept(aLong); + } + } downstream.end(); } @@ -481,10 +599,27 @@ final class SortedOps { } } + /** + * Abstract {@link Sink} for implementing sort on long streams. + */ + private static abstract class AbstractDoubleSortingSink extends Sink.ChainedDouble { + protected boolean cancellationWasRequested; + + AbstractDoubleSortingSink(Sink downstream) { + super(downstream); + } + + @Override + public final boolean cancellationRequested() { + cancellationWasRequested = true; + return false; + } + } + /** * {@link Sink} for implementing sort on SIZED double streams. */ - private static final class SizedDoubleSortingSink extends Sink.ChainedDouble { + private static final class SizedDoubleSortingSink extends AbstractDoubleSortingSink { private double[] array; private int offset; @@ -503,8 +638,14 @@ final class SortedOps { public void end() { Arrays.sort(array, 0, offset); downstream.begin(offset); - for (int i = 0; i < offset; i++) - downstream.accept(array[i]); + if (!cancellationWasRequested) { + for (int i = 0; i < offset; i++) + downstream.accept(array[i]); + } + else { + for (int i = 0; i < offset && !downstream.cancellationRequested(); i++) + downstream.accept(array[i]); + } downstream.end(); array = null; } @@ -518,7 +659,7 @@ final class SortedOps { /** * {@link Sink} for implementing sort on double streams. */ - private static final class DoubleSortingSink extends Sink.ChainedDouble { + private static final class DoubleSortingSink extends AbstractDoubleSortingSink { private SpinedBuffer.OfDouble b; DoubleSortingSink(Sink sink) { @@ -537,8 +678,16 @@ final class SortedOps { double[] doubles = b.asPrimitiveArray(); Arrays.sort(doubles); downstream.begin(doubles.length); - for (double aDouble : doubles) - downstream.accept(aDouble); + if (!cancellationWasRequested) { + for (double aDouble : doubles) + downstream.accept(aDouble); + } + else { + for (double aDouble : doubles) { + if (downstream.cancellationRequested()) break; + downstream.accept(aDouble); + } + } downstream.end(); } diff --git a/test/java/util/stream/test/org/openjdk/tests/java/util/stream/SortedOpTest.java b/test/java/util/stream/test/org/openjdk/tests/java/util/stream/SortedOpTest.java index 960e614fdc429ce3b50b0ba281e3a70c325929ff..3ca690e9096d3bdb37fc610cb7ff538bf1c53110 100644 --- a/test/java/util/stream/test/org/openjdk/tests/java/util/stream/SortedOpTest.java +++ b/test/java/util/stream/test/org/openjdk/tests/java/util/stream/SortedOpTest.java @@ -26,6 +26,9 @@ import org.testng.annotations.Test; import java.util.*; import java.util.Spliterators; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.*; @@ -122,24 +125,33 @@ public class SortedOpTest extends OpTestCase { @Test(groups = { "serialization-hostile" }) public void testSequentialShortCircuitTerminal() { - // The sorted op for sequential evaluation will buffer all elements when accepting - // then at the end sort those elements and push those elements downstream + // The sorted op for sequential evaluation will buffer all elements when + // accepting then at the end sort those elements and push those elements + // downstream + // A peek operation is added in-between the sorted() and terminal + // operation that counts the number of calls to its consumer and + // asserts that the number of calls is at most the required quantity List l = Arrays.asList(5, 4, 3, 2, 1); + Function> knownSize = i -> assertNCallsOnly( + l.stream().sorted(), Stream::peek, i); + Function> unknownSize = i -> assertNCallsOnly + (unknownSizeStream(l).sorted(), Stream::peek, i); + // Find - assertEquals(l.stream().sorted().findFirst(), Optional.of(1)); - assertEquals(l.stream().sorted().findAny(), Optional.of(1)); - assertEquals(unknownSizeStream(l).sorted().findFirst(), Optional.of(1)); - assertEquals(unknownSizeStream(l).sorted().findAny(), Optional.of(1)); + assertEquals(knownSize.apply(1).findFirst(), Optional.of(1)); + assertEquals(knownSize.apply(1).findAny(), Optional.of(1)); + assertEquals(unknownSize.apply(1).findFirst(), Optional.of(1)); + assertEquals(unknownSize.apply(1).findAny(), Optional.of(1)); // Match - assertEquals(l.stream().sorted().anyMatch(i -> i == 2), true); - assertEquals(l.stream().sorted().noneMatch(i -> i == 2), false); - assertEquals(l.stream().sorted().allMatch(i -> i == 2), false); - assertEquals(unknownSizeStream(l).sorted().anyMatch(i -> i == 2), true); - assertEquals(unknownSizeStream(l).sorted().noneMatch(i -> i == 2), false); - assertEquals(unknownSizeStream(l).sorted().allMatch(i -> i == 2), false); + assertEquals(knownSize.apply(2).anyMatch(i -> i == 2), true); + assertEquals(knownSize.apply(2).noneMatch(i -> i == 2), false); + assertEquals(knownSize.apply(2).allMatch(i -> i == 2), false); + assertEquals(unknownSize.apply(2).anyMatch(i -> i == 2), true); + assertEquals(unknownSize.apply(2).noneMatch(i -> i == 2), false); + assertEquals(unknownSize.apply(2).allMatch(i -> i == 2), false); } private Stream unknownSizeStream(List l) { @@ -199,19 +211,24 @@ public class SortedOpTest extends OpTestCase { public void testIntSequentialShortCircuitTerminal() { int[] a = new int[]{5, 4, 3, 2, 1}; + Function knownSize = i -> assertNCallsOnly( + Arrays.stream(a).sorted(), (s, c) -> s.peek(c::accept), i); + Function unknownSize = i -> assertNCallsOnly + (unknownSizeIntStream(a).sorted(), (s, c) -> s.peek(c::accept), i); + // Find - assertEquals(Arrays.stream(a).sorted().findFirst(), OptionalInt.of(1)); - assertEquals(Arrays.stream(a).sorted().findAny(), OptionalInt.of(1)); - assertEquals(unknownSizeIntStream(a).sorted().findFirst(), OptionalInt.of(1)); - assertEquals(unknownSizeIntStream(a).sorted().findAny(), OptionalInt.of(1)); + assertEquals(knownSize.apply(1).findFirst(), OptionalInt.of(1)); + assertEquals(knownSize.apply(1).findAny(), OptionalInt.of(1)); + assertEquals(unknownSize.apply(1).findFirst(), OptionalInt.of(1)); + assertEquals(unknownSize.apply(1).findAny(), OptionalInt.of(1)); // Match - assertEquals(Arrays.stream(a).sorted().anyMatch(i -> i == 2), true); - assertEquals(Arrays.stream(a).sorted().noneMatch(i -> i == 2), false); - assertEquals(Arrays.stream(a).sorted().allMatch(i -> i == 2), false); - assertEquals(unknownSizeIntStream(a).sorted().anyMatch(i -> i == 2), true); - assertEquals(unknownSizeIntStream(a).sorted().noneMatch(i -> i == 2), false); - assertEquals(unknownSizeIntStream(a).sorted().allMatch(i -> i == 2), false); + assertEquals(knownSize.apply(2).anyMatch(i -> i == 2), true); + assertEquals(knownSize.apply(2).noneMatch(i -> i == 2), false); + assertEquals(knownSize.apply(2).allMatch(i -> i == 2), false); + assertEquals(unknownSize.apply(2).anyMatch(i -> i == 2), true); + assertEquals(unknownSize.apply(2).noneMatch(i -> i == 2), false); + assertEquals(unknownSize.apply(2).allMatch(i -> i == 2), false); } private IntStream unknownSizeIntStream(int[] a) { @@ -242,19 +259,24 @@ public class SortedOpTest extends OpTestCase { public void testLongSequentialShortCircuitTerminal() { long[] a = new long[]{5, 4, 3, 2, 1}; + Function knownSize = i -> assertNCallsOnly( + Arrays.stream(a).sorted(), (s, c) -> s.peek(c::accept), i); + Function unknownSize = i -> assertNCallsOnly + (unknownSizeLongStream(a).sorted(), (s, c) -> s.peek(c::accept), i); + // Find - assertEquals(Arrays.stream(a).sorted().findFirst(), OptionalLong.of(1)); - assertEquals(Arrays.stream(a).sorted().findAny(), OptionalLong.of(1)); - assertEquals(unknownSizeLongStream(a).sorted().findFirst(), OptionalLong.of(1)); - assertEquals(unknownSizeLongStream(a).sorted().findAny(), OptionalLong.of(1)); + assertEquals(knownSize.apply(1).findFirst(), OptionalLong.of(1)); + assertEquals(knownSize.apply(1).findAny(), OptionalLong.of(1)); + assertEquals(unknownSize.apply(1).findFirst(), OptionalLong.of(1)); + assertEquals(unknownSize.apply(1).findAny(), OptionalLong.of(1)); // Match - assertEquals(Arrays.stream(a).sorted().anyMatch(i -> i == 2), true); - assertEquals(Arrays.stream(a).sorted().noneMatch(i -> i == 2), false); - assertEquals(Arrays.stream(a).sorted().allMatch(i -> i == 2), false); - assertEquals(unknownSizeLongStream(a).sorted().anyMatch(i -> i == 2), true); - assertEquals(unknownSizeLongStream(a).sorted().noneMatch(i -> i == 2), false); - assertEquals(unknownSizeLongStream(a).sorted().allMatch(i -> i == 2), false); + assertEquals(knownSize.apply(2).anyMatch(i -> i == 2), true); + assertEquals(knownSize.apply(2).noneMatch(i -> i == 2), false); + assertEquals(knownSize.apply(2).allMatch(i -> i == 2), false); + assertEquals(unknownSize.apply(2).anyMatch(i -> i == 2), true); + assertEquals(unknownSize.apply(2).noneMatch(i -> i == 2), false); + assertEquals(unknownSize.apply(2).allMatch(i -> i == 2), false); } private LongStream unknownSizeLongStream(long[] a) { @@ -285,19 +307,24 @@ public class SortedOpTest extends OpTestCase { public void testDoubleSequentialShortCircuitTerminal() { double[] a = new double[]{5.0, 4.0, 3.0, 2.0, 1.0}; + Function knownSize = i -> assertNCallsOnly( + Arrays.stream(a).sorted(), (s, c) -> s.peek(c::accept), i); + Function unknownSize = i -> assertNCallsOnly + (unknownSizeDoubleStream(a).sorted(), (s, c) -> s.peek(c::accept), i); + // Find - assertEquals(Arrays.stream(a).sorted().findFirst(), OptionalDouble.of(1)); - assertEquals(Arrays.stream(a).sorted().findAny(), OptionalDouble.of(1)); - assertEquals(unknownSizeDoubleStream(a).sorted().findFirst(), OptionalDouble.of(1)); - assertEquals(unknownSizeDoubleStream(a).sorted().findAny(), OptionalDouble.of(1)); + assertEquals(knownSize.apply(1).findFirst(), OptionalDouble.of(1)); + assertEquals(knownSize.apply(1).findAny(), OptionalDouble.of(1)); + assertEquals(unknownSize.apply(1).findFirst(), OptionalDouble.of(1)); + assertEquals(unknownSize.apply(1).findAny(), OptionalDouble.of(1)); // Match - assertEquals(Arrays.stream(a).sorted().anyMatch(i -> i == 2.0), true); - assertEquals(Arrays.stream(a).sorted().noneMatch(i -> i == 2.0), false); - assertEquals(Arrays.stream(a).sorted().allMatch(i -> i == 2.0), false); - assertEquals(unknownSizeDoubleStream(a).sorted().anyMatch(i -> i == 2.0), true); - assertEquals(unknownSizeDoubleStream(a).sorted().noneMatch(i -> i == 2.0), false); - assertEquals(unknownSizeDoubleStream(a).sorted().allMatch(i -> i == 2.0), false); + assertEquals(knownSize.apply(2).anyMatch(i -> i == 2.0), true); + assertEquals(knownSize.apply(2).noneMatch(i -> i == 2.0), false); + assertEquals(knownSize.apply(2).allMatch(i -> i == 2.0), false); + assertEquals(unknownSize.apply(2).anyMatch(i -> i == 2.0), true); + assertEquals(unknownSize.apply(2).noneMatch(i -> i == 2.0), false); + assertEquals(unknownSize.apply(2).allMatch(i -> i == 2.0), false); } private DoubleStream unknownSizeDoubleStream(double[] a) { @@ -321,4 +348,14 @@ public class SortedOpTest extends OpTestCase { assertSorted(result); assertContentsUnordered(data, result); } + + /** + * Interpose a consumer that asserts it is called at most N times. + */ + , R> S assertNCallsOnly(S s, BiFunction, S> pf, int n) { + AtomicInteger boxedInt = new AtomicInteger(); + return pf.apply(s, i -> { + assertFalse(boxedInt.incrementAndGet() > n, "Intermediate op called more than " + n + " time(s)"); + }); + } }