From 47a68adc7fa5da7743cb51fe69f81264e908e630 Mon Sep 17 00:00:00 2001 From: Markus Holzemer Date: Mon, 16 Jun 2014 09:41:46 +0200 Subject: [PATCH] [FLINK-836] Rework of the cached match driver --- .../compiler/dag/OptimizerNode.java | 10 - .../HashJoinBuildFirstProperties.java | 2 +- .../HashJoinBuildSecondProperties.java | 2 +- .../stratosphere/compiler/plan/Channel.java | 4 - .../stratosphere/compiler/plan/PlanNode.java | 4 - .../services/memorymanager/MemoryManager.java | 8 +- .../spi/DefaultMemoryManager.java | 4 +- ...BuildFirstReOpenableHashMatchIterator.java | 10 +- .../hash/BuildSecondHashMatchIterator.java | 6 +- ...uildSecondReOpenableHashMatchIterator.java | 71 +++ .../sort/AsynchronousPartialSorter.java | 13 +- .../runtime/sort/UnilateralSortMerger.java | 2 +- .../AbstractCachedBuildSideMatchDriver.java | 131 ++--- .../pact/runtime/task/RegularPactTask.java | 10 + .../sort/AsynchonousPartialSorterITCase.java | 6 +- .../runtime/task/CachedMatchTaskTest.java | 491 ++++++++++++++++++ .../runtime/test/util/DriverTestBase.java | 33 ++ 17 files changed, 683 insertions(+), 124 deletions(-) create mode 100644 stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildSecondReOpenableHashMatchIterator.java create mode 100644 stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/task/CachedMatchTaskTest.java diff --git a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/dag/OptimizerNode.java b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/dag/OptimizerNode.java index 6f295a5a945..ec9bd69be17 100644 --- a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/dag/OptimizerNode.java +++ b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/dag/OptimizerNode.java @@ -92,8 +92,6 @@ public abstract class OptimizerNode implements Visitable, Estimat protected boolean onDynamicPath; - protected boolean insideIteration; - protected List cachedPlans; // cache candidates, because the may be accessed repeatedly protected int[][] remappedKeys; @@ -501,14 +499,6 @@ public abstract class OptimizerNode implements Visitable, Estimat } } - public boolean isInsideIteration() { - return insideIteration; - } - - public void setInsideIteration(boolean insideIteration) { - this.insideIteration = insideIteration; - } - /** * Checks whether this node has branching output. A node's output is branched, if it has more * than one output connection. diff --git a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildFirstProperties.java b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildFirstProperties.java index 661b316f0b9..4f694c5a8f9 100644 --- a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildFirstProperties.java +++ b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildFirstProperties.java @@ -56,7 +56,7 @@ public class HashJoinBuildFirstProperties extends AbstractJoinDescriptor { public DualInputPlanNode instantiate(Channel in1, Channel in2, TwoInputNode node) { DriverStrategy strategy; - if(!in1.isOnDynamicPath() && in1.isInsideIteration() && in2.isInsideIteration()) { + if(!in1.isOnDynamicPath() && in2.isOnDynamicPath()) { strategy = DriverStrategy.HYBRIDHASH_BUILD_FIRST_CACHED; } else { diff --git a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildSecondProperties.java b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildSecondProperties.java index e085588c4ad..6bea65a3723 100644 --- a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildSecondProperties.java +++ b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/operators/HashJoinBuildSecondProperties.java @@ -53,7 +53,7 @@ public final class HashJoinBuildSecondProperties extends AbstractJoinDescriptor public DualInputPlanNode instantiate(Channel in1, Channel in2, TwoInputNode node) { DriverStrategy strategy; - if(!in2.isOnDynamicPath() && in1.isInsideIteration() && in2.isInsideIteration()) { + if(!in2.isOnDynamicPath() && in1.isOnDynamicPath()) { strategy = DriverStrategy.HYBRIDHASH_BUILD_SECOND_CACHED; } else { diff --git a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/Channel.java b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/Channel.java index d83da8226a5..6f9418fde85 100644 --- a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/Channel.java +++ b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/Channel.java @@ -306,10 +306,6 @@ public class Channel implements EstimateProvider, Cloneable, DumpableConnection< return this.source.isOnDynamicPath(); } - public boolean isInsideIteration() { - return this.source.isInsideIteration(); - } - public int getCostWeight() { return this.source.getCostWeight(); } diff --git a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/PlanNode.java b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/PlanNode.java index da24b5f36e2..539006cf961 100644 --- a/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/PlanNode.java +++ b/stratosphere-compiler/src/main/java/eu/stratosphere/compiler/plan/PlanNode.java @@ -423,10 +423,6 @@ public abstract class PlanNode implements Visitable, DumpableNode 1) { + throw new IllegalArgumentException("The fraction of memory to allocate must within (0, 1]."); } return (int)(this.totalNumPages * fraction / this.numberOfSlots); diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildFirstReOpenableHashMatchIterator.java b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildFirstReOpenableHashMatchIterator.java index 8c2b9cafe03..7898c41f457 100644 --- a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildFirstReOpenableHashMatchIterator.java +++ b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildFirstReOpenableHashMatchIterator.java @@ -37,15 +37,19 @@ public class BuildFirstReOpenableHashMatchIterator extends BuildFirst TypeSerializer serializer1, TypeComparator comparator1, TypeSerializer serializer2, TypeComparator comparator2, TypePairComparator pairComparator, - MemoryManager memManager, IOManager ioManager, - AbstractInvokable ownerTask, double memoryFraction) - throws MemoryAllocationException { + MemoryManager memManager, + IOManager ioManager, + AbstractInvokable ownerTask, + double memoryFraction) + throws MemoryAllocationException + { super(firstInput, secondInput, serializer1, comparator1, serializer2, comparator2, pairComparator, memManager, ioManager, ownerTask, memoryFraction); reopenHashTable = (ReOpenableMutableHashTable) hashJoin; } + @Override public MutableHashTable getHashJoin(TypeSerializer buildSideSerializer, TypeComparator buildSideComparator, TypeSerializer probeSideSerializer, TypeComparator probeSideComparator, TypePairComparator pairComparator, diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildSecondHashMatchIterator.java b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildSecondHashMatchIterator.java index 732d256fdf0..9f3fd97ae28 100644 --- a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildSecondHashMatchIterator.java +++ b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildSecondHashMatchIterator.java @@ -34,9 +34,9 @@ import eu.stratosphere.util.MutableObjectIterator; * An implementation of the {@link eu.stratosphere.pact.runtime.task.util.JoinTaskIterator} that uses a hybrid-hash-join * internally to match the records with equal key. The build side of the hash is the second input of the match. */ -public final class BuildSecondHashMatchIterator implements JoinTaskIterator { +public class BuildSecondHashMatchIterator implements JoinTaskIterator { - private final MutableHashTable hashJoin; + protected final MutableHashTable hashJoin; private final V2 nextBuildSideObject; @@ -44,7 +44,7 @@ public final class BuildSecondHashMatchIterator implements JoinTaskIt private final V1 probeCopy; - private final TypeSerializer probeSideSerializer; + protected final TypeSerializer probeSideSerializer; private final MemoryManager memManager; diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildSecondReOpenableHashMatchIterator.java b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildSecondReOpenableHashMatchIterator.java new file mode 100644 index 00000000000..597ae73d0f8 --- /dev/null +++ b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/hash/BuildSecondReOpenableHashMatchIterator.java @@ -0,0 +1,71 @@ +/*********************************************************************************************************************** + * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu) + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + **********************************************************************************************************************/ + +package eu.stratosphere.pact.runtime.hash; + +import java.io.IOException; +import java.util.List; + +import eu.stratosphere.api.common.typeutils.TypeComparator; +import eu.stratosphere.api.common.typeutils.TypePairComparator; +import eu.stratosphere.api.common.typeutils.TypeSerializer; +import eu.stratosphere.core.memory.MemorySegment; +import eu.stratosphere.nephele.services.iomanager.IOManager; +import eu.stratosphere.nephele.services.memorymanager.MemoryAllocationException; +import eu.stratosphere.nephele.services.memorymanager.MemoryManager; +import eu.stratosphere.nephele.template.AbstractInvokable; +import eu.stratosphere.util.MutableObjectIterator; + +public class BuildSecondReOpenableHashMatchIterator extends BuildSecondHashMatchIterator { + + + private final ReOpenableMutableHashTable reopenHashTable; + + public BuildSecondReOpenableHashMatchIterator( + MutableObjectIterator firstInput, + MutableObjectIterator secondInput, + TypeSerializer serializer1, TypeComparator comparator1, + TypeSerializer serializer2, TypeComparator comparator2, + TypePairComparator pairComparator, + MemoryManager memManager, + IOManager ioManager, + AbstractInvokable ownerTask, + double memoryFraction) + throws MemoryAllocationException + { + super(firstInput, secondInput, serializer1, comparator1, serializer2, + comparator2, pairComparator, memManager, ioManager, ownerTask, memoryFraction); + reopenHashTable = (ReOpenableMutableHashTable) hashJoin; + } + + @Override + public MutableHashTable getHashJoin(TypeSerializer buildSideSerializer, TypeComparator buildSideComparator, + TypeSerializer probeSideSerializer, TypeComparator probeSideComparator, + TypePairComparator pairComparator, + MemoryManager memManager, IOManager ioManager, AbstractInvokable ownerTask, double memoryFraction) + throws MemoryAllocationException + { + final int numPages = memManager.computeNumberOfPages(memoryFraction); + final List memorySegments = memManager.allocatePages(ownerTask, numPages); + return new ReOpenableMutableHashTable(buildSideSerializer, probeSideSerializer, buildSideComparator, probeSideComparator, pairComparator, memorySegments, ioManager); + } + + /** + * Set new input for probe side + * @throws IOException + */ + public void reopenProbe(MutableObjectIterator probeInput) throws IOException { + reopenHashTable.reopenProbe(probeInput); + } + +} diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/AsynchronousPartialSorter.java b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/AsynchronousPartialSorter.java index 35377cfeb10..f87f1f84477 100644 --- a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/AsynchronousPartialSorter.java +++ b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/AsynchronousPartialSorter.java @@ -34,8 +34,6 @@ import eu.stratosphere.util.MutableObjectIterator; */ public class AsynchronousPartialSorter extends UnilateralSortMerger { - private static final int MAX_MEM_PER_PARTIAL_SORT = 64 * 1024 * 0124; - private BufferQueueIterator bufferIterator; // ------------------------------------------------------------------------ @@ -62,11 +60,7 @@ public class AsynchronousPartialSorter extends UnilateralSortMerger { double memoryFraction) throws IOException, MemoryAllocationException { - super(memoryManager, null, input, parentTask, serializerFactory, comparator, memoryFraction, - memoryManager.computeNumberOfPages(memoryFraction) < 2 * MIN_NUM_SORT_MEM_SEGMENTS ? 1 : - Math.max((int) Math.ceil(((double) memoryManager.computeMemorySize(memoryFraction)) / - MAX_MEM_PER_PARTIAL_SORT), 2), - 2, 0.0f, true); + super(memoryManager, null, input, parentTask, serializerFactory, comparator, memoryFraction, 1, 2, 0.0f, true); } @@ -101,11 +95,6 @@ public class AsynchronousPartialSorter extends UnilateralSortMerger { // ------------------------------------------------------------------------ - /** - * This class implements an iterator over values from a {@link eu.stratosphere.pact.runtime.sort.BufferSortable}. - * The iterator returns the values of a given - * interval. - */ private final class BufferQueueIterator implements MutableObjectIterator { private final CircularQueues queues; diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/UnilateralSortMerger.java b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/UnilateralSortMerger.java index 6905b8524bc..9109e9acda1 100644 --- a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/UnilateralSortMerger.java +++ b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/sort/UnilateralSortMerger.java @@ -239,7 +239,7 @@ public class UnilateralSortMerger implements Sorter { * @param maxNumFileHandles The maximum number of files to be merged at once. * @param startSpillingFraction The faction of the buffers that have to be filled before the spilling thread * actually begins spilling data to disk. - * @param noSpilling When set to true, no memory will be allocated for writing and no spilling thread + * @param noSpillingMemory When set to true, no memory will be allocated for writing and no spilling thread * will be spawned. * * @throws IOException Thrown, if an error occurs initializing the resources for external sorting. diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/AbstractCachedBuildSideMatchDriver.java b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/AbstractCachedBuildSideMatchDriver.java index 1d3c55dd9a3..1c8c4275f6b 100644 --- a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/AbstractCachedBuildSideMatchDriver.java +++ b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/AbstractCachedBuildSideMatchDriver.java @@ -13,26 +13,20 @@ package eu.stratosphere.pact.runtime.task; -import java.util.List; - import eu.stratosphere.api.common.functions.GenericJoiner; import eu.stratosphere.api.common.typeutils.TypeComparator; import eu.stratosphere.api.common.typeutils.TypePairComparatorFactory; import eu.stratosphere.api.common.typeutils.TypeSerializer; -import eu.stratosphere.core.memory.MemorySegment; -import eu.stratosphere.pact.runtime.hash.MutableHashTable; +import eu.stratosphere.pact.runtime.hash.BuildFirstReOpenableHashMatchIterator; +import eu.stratosphere.pact.runtime.hash.BuildSecondReOpenableHashMatchIterator; +import eu.stratosphere.pact.runtime.task.util.JoinTaskIterator; import eu.stratosphere.pact.runtime.task.util.TaskConfig; -import eu.stratosphere.pact.runtime.util.EmptyMutableObjectIterator; import eu.stratosphere.util.Collector; import eu.stratosphere.util.MutableObjectIterator; public abstract class AbstractCachedBuildSideMatchDriver extends MatchDriver implements ResettablePactDriver, OT> { - - /** - * We keep it without generic parameters, because they vary depending on which input is the build side. - */ - protected volatile MutableHashTable hashJoin; + private volatile JoinTaskIterator matchIterator; private final int buildSideIndex; @@ -67,23 +61,39 @@ public abstract class AbstractCachedBuildSideMatchDriver extends M TypePairComparatorFactory pairComparatorFactory = this.taskContext.getTaskConfig().getPairComparatorFactory(this.taskContext.getUserCodeClassLoader()); - int numMemoryPages = this.taskContext.getMemoryManager().computeNumberOfPages(config.getRelativeMemoryDriver()); - List memSegments = this.taskContext.getMemoryManager().allocatePages( - this.taskContext.getOwningNepheleTask(), numMemoryPages); + double availableMemory = config.getRelativeMemoryDriver(); if (buildSideIndex == 0 && probeSideIndex == 1) { - MutableHashTable hashJoin = new MutableHashTable(serializer1, serializer2, comparator1, comparator2, - pairComparatorFactory.createComparator21(comparator1, comparator2), memSegments, this.taskContext.getIOManager()); - this.hashJoin = hashJoin; - hashJoin.open(input1, EmptyMutableObjectIterator.get()); + + matchIterator = + new BuildFirstReOpenableHashMatchIterator(input1, input2, + serializer1, comparator1, + serializer2, comparator2, + pairComparatorFactory.createComparator21(comparator1, comparator2), + this.taskContext.getMemoryManager(), + this.taskContext.getIOManager(), + this.taskContext.getOwningNepheleTask(), + availableMemory + ); + } else if (buildSideIndex == 1 && probeSideIndex == 0) { - MutableHashTable hashJoin = new MutableHashTable(serializer2, serializer1, comparator2, comparator1, - pairComparatorFactory.createComparator12(comparator1, comparator2), memSegments, this.taskContext.getIOManager()); - this.hashJoin = hashJoin; - hashJoin.open(input2, EmptyMutableObjectIterator.get()); + + matchIterator = + new BuildSecondReOpenableHashMatchIterator(input1, input2, + serializer1, comparator1, + serializer2, comparator2, + pairComparatorFactory.createComparator12(comparator1, comparator2), + this.taskContext.getMemoryManager(), + this.taskContext.getIOManager(), + this.taskContext.getOwningNepheleTask(), + availableMemory + ); + } else { throw new Exception("Error: Inconcistent setup for repeatable hash join driver."); } + + this.matchIterator.open(); } @Override @@ -98,63 +108,17 @@ public abstract class AbstractCachedBuildSideMatchDriver extends M final Collector collector = this.taskContext.getOutputCollector(); if (buildSideIndex == 0) { - final TypeSerializer buildSideSerializer = taskContext. getInputSerializer(0).getSerializer(); - final TypeSerializer probeSideSerializer = taskContext. getInputSerializer(1).getSerializer(); - IT1 buildSideRecordFirst; - IT1 buildSideRecordOther; - IT2 probeSideRecord; - IT2 probeSideRecordCopy; - final IT1 buildSideRecordFirstReuse = buildSideSerializer.createInstance(); - final IT1 buildSideRecordOtherReuse = buildSideSerializer.createInstance(); - final IT2 probeSideRecordReuse = probeSideSerializer.createInstance(); - final IT2 probeSideRecordCopyReuse = probeSideSerializer.createInstance(); + final BuildFirstReOpenableHashMatchIterator matchIterator = (BuildFirstReOpenableHashMatchIterator) this.matchIterator; - @SuppressWarnings("unchecked") - final MutableHashTable join = (MutableHashTable) this.hashJoin; + while (this.running && matchIterator != null && matchIterator.callWithNextKey(matchStub, collector)); - final MutableObjectIterator probeSideInput = taskContext.getInput(1); - - while (this.running && ((probeSideRecord = probeSideInput.next(probeSideRecordReuse)) != null)) { - final MutableHashTable.HashBucketIterator bucket = join.getMatchesFor(probeSideRecord); - - if ((buildSideRecordFirst = bucket.next(buildSideRecordFirstReuse)) != null) { - while ((buildSideRecordOther = bucket.next(buildSideRecordOtherReuse)) != null) { - probeSideRecordCopy = probeSideSerializer.copy(probeSideRecord, probeSideRecordCopyReuse); - matchStub.join(buildSideRecordOther, probeSideRecordCopy, collector); - } - matchStub.join(buildSideRecordFirst, probeSideRecord, collector); - } - } } else if (buildSideIndex == 1) { - final TypeSerializer buildSideSerializer = taskContext.getInputSerializer(1).getSerializer(); - final TypeSerializer probeSideSerializer = taskContext.getInputSerializer(0).getSerializer(); - IT2 buildSideRecordFirst; - IT2 buildSideRecordOther; - IT1 probeSideRecord; - IT1 probeSideRecordCopy; - final IT2 buildSideRecordFirstReuse = buildSideSerializer.createInstance(); - final IT2 buildSideRecordOtherReuse = buildSideSerializer.createInstance(); - final IT1 probeSideRecordReuse = probeSideSerializer.createInstance(); - final IT1 probeSideRecordCopyReuse = probeSideSerializer.createInstance(); - - @SuppressWarnings("unchecked") - final MutableHashTable join = (MutableHashTable) this.hashJoin; + final BuildSecondReOpenableHashMatchIterator matchIterator = (BuildSecondReOpenableHashMatchIterator) this.matchIterator; - final MutableObjectIterator probeSideInput = taskContext.getInput(0); + while (this.running && matchIterator != null && matchIterator.callWithNextKey(matchStub, collector)); - while (this.running && ((probeSideRecord = probeSideInput.next(probeSideRecordReuse)) != null)) { - final MutableHashTable.HashBucketIterator bucket = join.getMatchesFor(probeSideRecord); - - if ((buildSideRecordFirst = bucket.next(buildSideRecordFirstReuse)) != null) { - while ((buildSideRecordOther = bucket.next(buildSideRecordOtherReuse)) != null) { - probeSideRecordCopy = probeSideSerializer.copy(probeSideRecord, probeSideRecordCopyReuse); - matchStub.join(probeSideRecordCopy, buildSideRecordOther, collector); - } - matchStub.join(probeSideRecord, buildSideRecordFirst, collector); - } - } } else { throw new Exception(); } @@ -164,21 +128,34 @@ public abstract class AbstractCachedBuildSideMatchDriver extends M public void cleanup() throws Exception {} @Override - public void reset() throws Exception {} + public void reset() throws Exception { + + MutableObjectIterator input1 = this.taskContext.getInput(0); + MutableObjectIterator input2 = this.taskContext.getInput(1); + + if (buildSideIndex == 0 && probeSideIndex == 1) { + final BuildFirstReOpenableHashMatchIterator matchIterator = (BuildFirstReOpenableHashMatchIterator) this.matchIterator; + matchIterator.reopenProbe(input2); + } + else { + final BuildSecondReOpenableHashMatchIterator matchIterator = (BuildSecondReOpenableHashMatchIterator) this.matchIterator; + matchIterator.reopenProbe(input1); + } + } @Override public void teardown() { - MutableHashTable ht = this.hashJoin; - if (ht != null) { - ht.close(); + this.running = false; + if (this.matchIterator != null) { + this.matchIterator.close(); } } @Override public void cancel() { this.running = false; - if (this.hashJoin != null) { - this.hashJoin.close(); + if (this.matchIterator != null) { + this.matchIterator.abort(); } } } diff --git a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/RegularPactTask.java b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/RegularPactTask.java index 31405254fd8..3e4a1fd7fe2 100644 --- a/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/RegularPactTask.java +++ b/stratosphere-runtime/src/main/java/eu/stratosphere/pact/runtime/task/RegularPactTask.java @@ -531,6 +531,16 @@ public class RegularPactTask extends AbstractInvokable i } catch (Throwable t) {} } + + // if resettable driver invoke treardown + if (this.driver instanceof ResettablePactDriver) { + final ResettablePactDriver resDriver = (ResettablePactDriver) this.driver; + try { + resDriver.teardown(); + } catch (Throwable t) { + throw new Exception("Error while shutting down an iterative operator: " + t.getMessage(), t); + } + } RegularPactTask.cancelChainedTasks(this.chainedTasks); diff --git a/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/sort/AsynchonousPartialSorterITCase.java b/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/sort/AsynchonousPartialSorterITCase.java index f191075c6e8..155bf285320 100644 --- a/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/sort/AsynchonousPartialSorterITCase.java +++ b/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/sort/AsynchonousPartialSorterITCase.java @@ -128,7 +128,7 @@ public class AsynchonousPartialSorterITCase { // merge iterator LOG.debug("Initializing sortmerger..."); Sorter sorter = new AsynchronousPartialSorter(this.memoryManager, source, - this.parentTask, this.serializer, this.comparator, 1.0); + this.parentTask, this.serializer, this.comparator, 0.2); runPartialSorter(sorter, NUM_RECORDS, 2); } @@ -151,9 +151,9 @@ public class AsynchonousPartialSorterITCase { // merge iterator LOG.debug("Initializing sortmerger..."); Sorter sorter = new AsynchronousPartialSorter(this.memoryManager, source, - this.parentTask, this.serializer, this.comparator, 1.0); + this.parentTask, this.serializer, this.comparator, 0.15); - runPartialSorter(sorter, NUM_RECORDS, 28); + runPartialSorter(sorter, NUM_RECORDS, 27); } catch (Exception t) { t.printStackTrace(); diff --git a/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/task/CachedMatchTaskTest.java b/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/task/CachedMatchTaskTest.java new file mode 100644 index 00000000000..ff560df380d --- /dev/null +++ b/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/task/CachedMatchTaskTest.java @@ -0,0 +1,491 @@ +/*********************************************************************************************************************** + * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu) + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + **********************************************************************************************************************/ + +package eu.stratosphere.pact.runtime.task; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.junit.Assert; +import org.junit.Test; + +import eu.stratosphere.api.common.functions.GenericJoiner; +import eu.stratosphere.api.java.record.functions.JoinFunction; +import eu.stratosphere.api.java.typeutils.runtime.record.RecordComparator; +import eu.stratosphere.api.java.typeutils.runtime.record.RecordPairComparatorFactory; +import eu.stratosphere.pact.runtime.test.util.DelayingInfinitiveInputIterator; +import eu.stratosphere.pact.runtime.test.util.DriverTestBase; +import eu.stratosphere.pact.runtime.test.util.ExpectedTestException; +import eu.stratosphere.pact.runtime.test.util.NirvanaOutputList; +import eu.stratosphere.pact.runtime.test.util.TaskCancelThread; +import eu.stratosphere.pact.runtime.test.util.UniformRecordGenerator; +import eu.stratosphere.types.IntValue; +import eu.stratosphere.types.Key; +import eu.stratosphere.types.Record; +import eu.stratosphere.util.Collector; + +public class CachedMatchTaskTest extends DriverTestBase> +{ + private static final long HASH_MEM = 6*1024*1024; + + private static final long SORT_MEM = 3*1024*1024; + + @SuppressWarnings("unchecked") + private final RecordComparator comparator1 = new RecordComparator( + new int[]{0}, (Class>[])new Class[]{ IntValue.class }); + + @SuppressWarnings("unchecked") + private final RecordComparator comparator2 = new RecordComparator( + new int[]{0}, (Class>[])new Class[]{ IntValue.class }); + + private final List outList = new ArrayList(); + + + public CachedMatchTaskTest() { + super(HASH_MEM, 2, SORT_MEM); + } + + @Test + public void testHash1MatchTask() { + int keyCnt1 = 20; + int valCnt1 = 1; + + int keyCnt2 = 10; + int valCnt2 = 2; + + addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false)); + addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false)); + addInputComparator(this.comparator1); + addInputComparator(this.comparator2); + getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); + setOutput(this.outList); + getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST_CACHED); + getTaskConfig().setRelativeMemoryDriver(1.0f); + + BuildFirstCachedMatchDriver testTask = new BuildFirstCachedMatchDriver(); + + try { + testResettableDriver(testTask, MockMatchStub.class, 3); + } catch (Exception e) { + e.printStackTrace(); + Assert.fail("Test caused an exception."); + } + + final int expCnt = valCnt1*valCnt2*Math.min(keyCnt1, keyCnt2); + Assert.assertEquals("Wrong result set size.", expCnt, this.outList.size()); + this.outList.clear(); + } + + @Test + public void testHash2MatchTask() { + int keyCnt1 = 20; + int valCnt1 = 1; + + int keyCnt2 = 20; + int valCnt2 = 1; + + addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false)); + addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false)); + addInputComparator(this.comparator1); + addInputComparator(this.comparator2); + getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); + setOutput(this.outList); + getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND_CACHED); + getTaskConfig().setRelativeMemoryDriver(1.0f); + + BuildSecondCachedMatchDriver testTask = new BuildSecondCachedMatchDriver(); + + try { + testResettableDriver(testTask, MockMatchStub.class, 3); + } catch (Exception e) { + e.printStackTrace(); + Assert.fail("Test caused an exception."); + } + + final int expCnt = valCnt1*valCnt2*Math.min(keyCnt1, keyCnt2); + Assert.assertEquals("Wrong result set size.", expCnt, this.outList.size()); + this.outList.clear(); + } + + @Test + public void testHash3MatchTask() { + int keyCnt1 = 20; + int valCnt1 = 1; + + int keyCnt2 = 20; + int valCnt2 = 20; + + addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false)); + addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false)); + addInputComparator(this.comparator1); + addInputComparator(this.comparator2); + getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); + setOutput(this.outList); + getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST_CACHED); + getTaskConfig().setRelativeMemoryDriver(1.0f); + + BuildFirstCachedMatchDriver testTask = new BuildFirstCachedMatchDriver(); + + try { + testResettableDriver(testTask, MockMatchStub.class, 3); + } catch (Exception e) { + e.printStackTrace(); + Assert.fail("Test caused an exception."); + } + + final int expCnt = valCnt1*valCnt2*Math.min(keyCnt1, keyCnt2); + Assert.assertEquals("Wrong result set size.", expCnt, this.outList.size()); + this.outList.clear(); + } + + @Test + public void testHash4MatchTask() { + int keyCnt1 = 20; + int valCnt1 = 20; + + int keyCnt2 = 20; + int valCnt2 = 1; + + addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false)); + addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false)); + addInputComparator(this.comparator1); + addInputComparator(this.comparator2); + getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); + setOutput(this.outList); + getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND_CACHED); + getTaskConfig().setRelativeMemoryDriver(1.0f); + + BuildSecondCachedMatchDriver testTask = new BuildSecondCachedMatchDriver(); + + try { + testResettableDriver(testTask, MockMatchStub.class, 3); + } catch (Exception e) { + e.printStackTrace(); + Assert.fail("Test caused an exception."); + } + + final int expCnt = valCnt1*valCnt2*Math.min(keyCnt1, keyCnt2); + Assert.assertEquals("Wrong result set size.", expCnt, this.outList.size()); + this.outList.clear(); + } + + @Test + public void testHash5MatchTask() { + int keyCnt1 = 20; + int valCnt1 = 20; + + int keyCnt2 = 20; + int valCnt2 = 20; + + addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false)); + addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false)); + addInputComparator(this.comparator1); + addInputComparator(this.comparator2); + getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); + setOutput(this.outList); + getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST_CACHED); + getTaskConfig().setRelativeMemoryDriver(1.0f); + + BuildFirstCachedMatchDriver testTask = new BuildFirstCachedMatchDriver(); + + try { + testResettableDriver(testTask, MockMatchStub.class, 3); + } catch (Exception e) { + e.printStackTrace(); + Assert.fail("Test caused an exception."); + } + + final int expCnt = valCnt1*valCnt2*Math.min(keyCnt1, keyCnt2); + Assert.assertEquals("Wrong result set size.", expCnt, this.outList.size()); + this.outList.clear(); + } + + @Test + public void testFailingHashFirstMatchTask() { + int keyCnt1 = 20; + int valCnt1 = 20; + + int keyCnt2 = 20; + int valCnt2 = 20; + + addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false)); + addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false)); + addInputComparator(this.comparator1); + addInputComparator(this.comparator2); + getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); + setOutput(new NirvanaOutputList()); + getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST_CACHED); + getTaskConfig().setRelativeMemoryDriver(1.0f); + + BuildFirstCachedMatchDriver testTask = new BuildFirstCachedMatchDriver(); + + try { + testResettableDriver(testTask, MockFailingMatchStub.class, 3); + Assert.fail("Function exception was not forwarded."); + } catch (ExpectedTestException etex) { + // good! + } catch (Exception e) { + e.printStackTrace(); + Assert.fail("Test caused an exception."); + } + } + + @Test + public void testFailingHashSecondMatchTask() { + int keyCnt1 = 20; + int valCnt1 = 20; + + int keyCnt2 = 20; + int valCnt2 = 20; + + addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false)); + addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false)); + addInputComparator(this.comparator1); + addInputComparator(this.comparator2); + getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); + setOutput(new NirvanaOutputList()); + getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND_CACHED); + getTaskConfig().setRelativeMemoryDriver(1.0f); + + BuildSecondCachedMatchDriver testTask = new BuildSecondCachedMatchDriver(); + + try { + testResettableDriver(testTask, MockFailingMatchStub.class, 3); + Assert.fail("Function exception was not forwarded."); + } catch (ExpectedTestException etex) { + // good! + } catch (Exception e) { + e.printStackTrace(); + Assert.fail("Test caused an exception."); + } + } + + @Test + public void testCancelHashMatchTaskWhileBuildFirst() { + int keyCnt = 20; + int valCnt = 20; + + addInput(new DelayingInfinitiveInputIterator(100)); + addInput(new UniformRecordGenerator(keyCnt, valCnt, false)); + + addInputComparator(this.comparator1); + addInputComparator(this.comparator2); + + getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); + + setOutput(new NirvanaOutputList()); + + getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST_CACHED); + getTaskConfig().setRelativeMemoryDriver(1.0f); + + final BuildFirstCachedMatchDriver testTask = new BuildFirstCachedMatchDriver(); + + final AtomicBoolean success = new AtomicBoolean(false); + + Thread taskRunner = new Thread() { + @Override + public void run() { + try { + testDriver(testTask, MockFailingMatchStub.class); + success.set(true); + } catch (Exception ie) { + ie.printStackTrace(); + } + } + }; + taskRunner.start(); + + TaskCancelThread tct = new TaskCancelThread(1, taskRunner, this); + tct.start(); + + try { + tct.join(); + taskRunner.join(); + } catch(InterruptedException ie) { + Assert.fail("Joining threads failed"); + } + + Assert.assertTrue("Test threw an exception even though it was properly canceled.", success.get()); + } + + @Test + public void testHashCancelMatchTaskWhileBuildSecond() { + int keyCnt = 20; + int valCnt = 20; + + addInput(new UniformRecordGenerator(keyCnt, valCnt, false)); + addInput(new DelayingInfinitiveInputIterator(100)); + addInputComparator(this.comparator1); + addInputComparator(this.comparator2); + getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); + setOutput(new NirvanaOutputList()); + getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND_CACHED); + getTaskConfig().setRelativeMemoryDriver(1.0f); + + final BuildSecondCachedMatchDriver testTask = new BuildSecondCachedMatchDriver(); + + final AtomicBoolean success = new AtomicBoolean(false); + + Thread taskRunner = new Thread() { + @Override + public void run() { + try { + testDriver(testTask, MockMatchStub.class); + success.set(true); + } catch (Exception ie) { + ie.printStackTrace(); + } + } + }; + taskRunner.start(); + + TaskCancelThread tct = new TaskCancelThread(1, taskRunner, this); + tct.start(); + + try { + tct.join(); + taskRunner.join(); + } catch(InterruptedException ie) { + Assert.fail("Joining threads failed"); + } + + Assert.assertTrue("Test threw an exception even though it was properly canceled.", success.get()); + } + + @Test + public void testHashFirstCancelMatchTaskWhileMatching() { + int keyCnt = 20; + int valCnt = 20; + + addInput(new UniformRecordGenerator(keyCnt, valCnt, false)); + addInput(new UniformRecordGenerator(keyCnt, valCnt, false)); + addInputComparator(this.comparator1); + addInputComparator(this.comparator2); + getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); + setOutput(new NirvanaOutputList()); + getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST); + getTaskConfig().setRelativeMemoryDriver(1.0f); + + final BuildFirstCachedMatchDriver testTask = new BuildFirstCachedMatchDriver(); + + final AtomicBoolean success = new AtomicBoolean(false); + + Thread taskRunner = new Thread() { + @Override + public void run() { + try { + testDriver(testTask, MockMatchStub.class); + success.set(true); + } catch (Exception ie) { + ie.printStackTrace(); + } + } + }; + taskRunner.start(); + + TaskCancelThread tct = new TaskCancelThread(1, taskRunner, this); + tct.start(); + + try { + tct.join(); + taskRunner.join(); + } catch(InterruptedException ie) { + Assert.fail("Joining threads failed"); + } + + Assert.assertTrue("Test threw an exception even though it was properly canceled.", success.get()); + } + + @Test + public void testHashSecondCancelMatchTaskWhileMatching() { + int keyCnt = 20; + int valCnt = 20; + + addInput(new UniformRecordGenerator(keyCnt, valCnt, false)); + addInput(new UniformRecordGenerator(keyCnt, valCnt, false)); + addInputComparator(this.comparator1); + addInputComparator(this.comparator2); + getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get()); + setOutput(new NirvanaOutputList()); + getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND); + getTaskConfig().setRelativeMemoryDriver(1.0f); + + final BuildSecondCachedMatchDriver testTask = new BuildSecondCachedMatchDriver(); + + + final AtomicBoolean success = new AtomicBoolean(false); + + Thread taskRunner = new Thread() { + @Override + public void run() { + try { + testDriver(testTask, MockMatchStub.class); + success.set(true); + } catch (Exception ie) { + ie.printStackTrace(); + } + } + }; + taskRunner.start(); + + TaskCancelThread tct = new TaskCancelThread(1, taskRunner, this); + tct.start(); + + try { + tct.join(); + taskRunner.join(); + } catch(InterruptedException ie) { + Assert.fail("Joining threads failed"); + } + + Assert.assertTrue("Test threw an exception even though it was properly canceled.", success.get()); + } + + // ================================================================================================= + + public static final class MockMatchStub extends JoinFunction { + private static final long serialVersionUID = 1L; + + @Override + public void join(Record record1, Record record2, Collector out) throws Exception { + out.collect(record1); + } + } + + public static final class MockFailingMatchStub extends JoinFunction { + private static final long serialVersionUID = 1L; + + private int cnt = 0; + + @Override + public void join(Record record1, Record record2, Collector out) { + if (++this.cnt >= 10) { + throw new ExpectedTestException(); + } + + out.collect(record1); + } + } + + public static final class MockDelayingMatchStub extends JoinFunction { + private static final long serialVersionUID = 1L; + + @Override + public void join(Record record1, Record record2, Collector out) { + try { + Thread.sleep(100); + } catch (InterruptedException e) { } + } + } +} diff --git a/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/test/util/DriverTestBase.java b/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/test/util/DriverTestBase.java index 531382e4d11..41562048a25 100644 --- a/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/test/util/DriverTestBase.java +++ b/stratosphere-runtime/src/test/java/eu/stratosphere/pact/runtime/test/util/DriverTestBase.java @@ -35,6 +35,7 @@ import eu.stratosphere.api.java.typeutils.runtime.record.RecordSerializerFactory import eu.stratosphere.pact.runtime.sort.UnilateralSortMerger; import eu.stratosphere.pact.runtime.task.PactDriver; import eu.stratosphere.pact.runtime.task.PactTaskContext; +import eu.stratosphere.pact.runtime.task.ResettablePactDriver; import eu.stratosphere.pact.runtime.task.util.TaskConfig; import eu.stratosphere.types.Record; import eu.stratosphere.util.Collector; @@ -194,16 +195,48 @@ public class DriverTestBase implements PactTaskContext resDriver = (ResettablePactDriver) this.driver; + try { + resDriver.teardown(); + } catch (Throwable t) { + throw new Exception("Error while shutting down an iterative operator: " + t.getMessage(), t); + } + } + // drop exception, if the task was canceled if (this.running) { throw ex; } + } finally { driver.cleanup(); } } + @SuppressWarnings({"unchecked","rawtypes"}) + public void testResettableDriver(ResettablePactDriver driver, Class stubClass, int iterations) throws Exception { + + driver.setup(this); + + for(int i = 0; i < iterations; i++) { + + if(i == 0) { + driver.initialize(); + } + else { + driver.reset(); + } + + testDriver(driver, stubClass); + + } + + driver.teardown(); + } + public void cancel() throws Exception { this.running = false; this.driver.cancel(); -- GitLab