提交 47a68adc 编写于 作者: M Markus Holzemer 提交者: Stephan Ewen

[FLINK-836] Rework of the cached match driver

上级 99c888c7
......@@ -92,8 +92,6 @@ public abstract class OptimizerNode implements Visitable<OptimizerNode>, Estimat
protected boolean onDynamicPath;
protected boolean insideIteration;
protected List<PlanNode> cachedPlans; // cache candidates, because they may be accessed repeatedly
protected int[][] remappedKeys;
......@@ -501,14 +499,6 @@ public abstract class OptimizerNode implements Visitable<OptimizerNode>, Estimat
}
}
public boolean isInsideIteration() {
return insideIteration;
}
public void setInsideIteration(boolean insideIteration) {
this.insideIteration = insideIteration;
}
/**
* Checks whether this node has branching output. A node's output is branched, if it has more
* than one output connection.
......
......@@ -56,7 +56,7 @@ public class HashJoinBuildFirstProperties extends AbstractJoinDescriptor {
public DualInputPlanNode instantiate(Channel in1, Channel in2, TwoInputNode node) {
DriverStrategy strategy;
if(!in1.isOnDynamicPath() && in1.isInsideIteration() && in2.isInsideIteration()) {
if(!in1.isOnDynamicPath() && in2.isOnDynamicPath()) {
strategy = DriverStrategy.HYBRIDHASH_BUILD_FIRST_CACHED;
}
else {
......
......@@ -53,7 +53,7 @@ public final class HashJoinBuildSecondProperties extends AbstractJoinDescriptor
public DualInputPlanNode instantiate(Channel in1, Channel in2, TwoInputNode node) {
DriverStrategy strategy;
if(!in2.isOnDynamicPath() && in1.isInsideIteration() && in2.isInsideIteration()) {
if(!in2.isOnDynamicPath() && in1.isOnDynamicPath()) {
strategy = DriverStrategy.HYBRIDHASH_BUILD_SECOND_CACHED;
}
else {
......
......@@ -306,10 +306,6 @@ public class Channel implements EstimateProvider, Cloneable, DumpableConnection<
return this.source.isOnDynamicPath();
}
public boolean isInsideIteration() {
return this.source.isInsideIteration();
}
public int getCostWeight() {
return this.source.getCostWeight();
}
......
......@@ -423,10 +423,6 @@ public abstract class PlanNode implements Visitable<PlanNode>, DumpableNode<Plan
return this.template.getCostWeight();
}
public boolean isInsideIteration() {
return this.template.isInsideIteration();
}
// --------------------------------------------------------------------------------------------
/**
......
......@@ -72,7 +72,8 @@ public interface MemoryManager {
/**
* Returns the total size of memory.
* @return
*
* @return The total size of memory.
*/
long getMemorySize();
......@@ -88,8 +89,9 @@ public interface MemoryManager {
/**
* Computes the memory size of the fraction per slot.
* @param fraction
* @return
*
* @param fraction The fraction of the memory of the task slot.
* @return The number of pages corresponding to the memory fraction.
*/
long computeMemorySize(double fraction);
......
......@@ -394,8 +394,8 @@ public class DefaultMemoryManager implements MemoryManager {
}
private final int getRelativeNumPages(double fraction){
if(fraction < 0){
throw new IllegalArgumentException("The fraction of memory to allocate must not be negative.");
if (fraction <= 0 || fraction > 1) {
throw new IllegalArgumentException("The fraction of memory to allocate must within (0, 1].");
}
return (int)(this.totalNumPages * fraction / this.numberOfSlots);
......
......@@ -37,15 +37,19 @@ public class BuildFirstReOpenableHashMatchIterator<V1, V2, O> extends BuildFirst
TypeSerializer<V1> serializer1, TypeComparator<V1> comparator1,
TypeSerializer<V2> serializer2, TypeComparator<V2> comparator2,
TypePairComparator<V2, V1> pairComparator,
MemoryManager memManager, IOManager ioManager,
AbstractInvokable ownerTask, double memoryFraction)
throws MemoryAllocationException {
MemoryManager memManager,
IOManager ioManager,
AbstractInvokable ownerTask,
double memoryFraction)
throws MemoryAllocationException
{
super(firstInput, secondInput, serializer1, comparator1, serializer2,
comparator2, pairComparator, memManager, ioManager, ownerTask,
memoryFraction);
reopenHashTable = (ReOpenableMutableHashTable<V1, V2>) hashJoin;
}
@Override
public <BT, PT> MutableHashTable<BT, PT> getHashJoin(TypeSerializer<BT> buildSideSerializer, TypeComparator<BT> buildSideComparator,
TypeSerializer<PT> probeSideSerializer, TypeComparator<PT> probeSideComparator,
TypePairComparator<PT, BT> pairComparator,
......
......@@ -34,9 +34,9 @@ import eu.stratosphere.util.MutableObjectIterator;
* An implementation of the {@link eu.stratosphere.pact.runtime.task.util.JoinTaskIterator} that uses a hybrid-hash-join
* internally to match the records with equal key. The build side of the hash is the second input of the match.
*/
public final class BuildSecondHashMatchIterator<V1, V2, O> implements JoinTaskIterator<V1, V2, O> {
public class BuildSecondHashMatchIterator<V1, V2, O> implements JoinTaskIterator<V1, V2, O> {
private final MutableHashTable<V2, V1> hashJoin;
protected final MutableHashTable<V2, V1> hashJoin;
private final V2 nextBuildSideObject;
......@@ -44,7 +44,7 @@ public final class BuildSecondHashMatchIterator<V1, V2, O> implements JoinTaskIt
private final V1 probeCopy;
private final TypeSerializer<V1> probeSideSerializer;
protected final TypeSerializer<V1> probeSideSerializer;
private final MemoryManager memManager;
......
/***********************************************************************************************************************
* Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
**********************************************************************************************************************/
package eu.stratosphere.pact.runtime.hash;
import java.io.IOException;
import java.util.List;
import eu.stratosphere.api.common.typeutils.TypeComparator;
import eu.stratosphere.api.common.typeutils.TypePairComparator;
import eu.stratosphere.api.common.typeutils.TypeSerializer;
import eu.stratosphere.core.memory.MemorySegment;
import eu.stratosphere.nephele.services.iomanager.IOManager;
import eu.stratosphere.nephele.services.memorymanager.MemoryAllocationException;
import eu.stratosphere.nephele.services.memorymanager.MemoryManager;
import eu.stratosphere.nephele.template.AbstractInvokable;
import eu.stratosphere.util.MutableObjectIterator;
/**
 * Variant of {@link BuildSecondHashMatchIterator} that is backed by a
 * {@link ReOpenableMutableHashTable}, so that a new probe-side input can be supplied
 * through {@link #reopenProbe(MutableObjectIterator)}.
 *
 * @param <V1> Type of the first input (the input handed to {@code reopenProbe}).
 * @param <V2> Type of the second input.
 * @param <O> Output type, passed through to the parent iterator.
 */
public class BuildSecondReOpenableHashMatchIterator<V1, V2, O> extends BuildSecondHashMatchIterator<V1, V2, O> {

	/** The parent's hash table, viewed through its re-openable type. */
	private final ReOpenableMutableHashTable<V2, V1> reopenHashTable;

	/**
	 * Forwards all arguments unchanged to the parent constructor and then keeps a typed
	 * reference to the hash table that the parent stored in {@code hashJoin}.
	 *
	 * @throws MemoryAllocationException Declared by the parent constructor.
	 */
	public BuildSecondReOpenableHashMatchIterator(
			MutableObjectIterator<V1> firstInput,
			MutableObjectIterator<V2> secondInput,
			TypeSerializer<V1> serializer1, TypeComparator<V1> comparator1,
			TypeSerializer<V2> serializer2, TypeComparator<V2> comparator2,
			TypePairComparator<V1, V2> pairComparator,
			MemoryManager memManager,
			IOManager ioManager,
			AbstractInvokable ownerTask,
			double memoryFraction)
		throws MemoryAllocationException
	{
		super(firstInput, secondInput,
				serializer1, comparator1,
				serializer2, comparator2,
				pairComparator, memManager, ioManager, ownerTask, memoryFraction);

		// NOTE(review): this downcast only succeeds if the parent created the table via
		// getHashJoin(...) below - confirm against BuildSecondHashMatchIterator.
		this.reopenHashTable = (ReOpenableMutableHashTable<V2, V1>) this.hashJoin;
	}

	/**
	 * Creates the hash table used by this iterator: allocates the pages that correspond to
	 * {@code memoryFraction} on behalf of {@code ownerTask} and wraps them in a
	 * {@link ReOpenableMutableHashTable}.
	 *
	 * @return A new {@link ReOpenableMutableHashTable} over the freshly allocated segments.
	 * @throws MemoryAllocationException Declared for the page allocation.
	 */
	@Override
	public <BT, PT> MutableHashTable<BT, PT> getHashJoin(
			TypeSerializer<BT> buildSideSerializer, TypeComparator<BT> buildSideComparator,
			TypeSerializer<PT> probeSideSerializer, TypeComparator<PT> probeSideComparator,
			TypePairComparator<PT, BT> pairComparator,
			MemoryManager memManager, IOManager ioManager, AbstractInvokable ownerTask, double memoryFraction)
		throws MemoryAllocationException
	{
		final List<MemorySegment> segments =
				memManager.allocatePages(ownerTask, memManager.computeNumberOfPages(memoryFraction));

		return new ReOpenableMutableHashTable<BT, PT>(
				buildSideSerializer, probeSideSerializer,
				buildSideComparator, probeSideComparator,
				pairComparator, segments, ioManager);
	}

	/**
	 * Sets a new input for the probe side by delegating to the re-openable hash table.
	 *
	 * @param probeInput The new probe-side input.
	 * @throws IOException Forwarded from the hash table's {@code reopenProbe}.
	 */
	public void reopenProbe(MutableObjectIterator<V1> probeInput) throws IOException {
		this.reopenHashTable.reopenProbe(probeInput);
	}
}
......@@ -34,8 +34,6 @@ import eu.stratosphere.util.MutableObjectIterator;
*/
public class AsynchronousPartialSorter<E> extends UnilateralSortMerger<E> {
private static final int MAX_MEM_PER_PARTIAL_SORT = 64 * 1024 * 0124;
private BufferQueueIterator bufferIterator;
// ------------------------------------------------------------------------
......@@ -62,11 +60,7 @@ public class AsynchronousPartialSorter<E> extends UnilateralSortMerger<E> {
double memoryFraction)
throws IOException, MemoryAllocationException
{
super(memoryManager, null, input, parentTask, serializerFactory, comparator, memoryFraction,
memoryManager.computeNumberOfPages(memoryFraction) < 2 * MIN_NUM_SORT_MEM_SEGMENTS ? 1 :
Math.max((int) Math.ceil(((double) memoryManager.computeMemorySize(memoryFraction)) /
MAX_MEM_PER_PARTIAL_SORT), 2),
2, 0.0f, true);
super(memoryManager, null, input, parentTask, serializerFactory, comparator, memoryFraction, 1, 2, 0.0f, true);
}
......@@ -101,11 +95,6 @@ public class AsynchronousPartialSorter<E> extends UnilateralSortMerger<E> {
// ------------------------------------------------------------------------
/**
* This class implements an iterator over values from a {@link eu.stratosphere.pact.runtime.sort.BufferSortable}.
* The iterator returns the values of a given
* interval.
*/
private final class BufferQueueIterator implements MutableObjectIterator<E> {
private final CircularQueues<E> queues;
......
......@@ -239,7 +239,7 @@ public class UnilateralSortMerger<E> implements Sorter<E> {
* @param maxNumFileHandles The maximum number of files to be merged at once.
* @param startSpillingFraction The fraction of the buffers that have to be filled before the spilling thread
* actually begins spilling data to disk.
* @param noSpilling When set to true, no memory will be allocated for writing and no spilling thread
* @param noSpillingMemory When set to true, no memory will be allocated for writing and no spilling thread
* will be spawned.
*
* @throws IOException Thrown, if an error occurs initializing the resources for external sorting.
......
......@@ -13,26 +13,20 @@
package eu.stratosphere.pact.runtime.task;
import java.util.List;
import eu.stratosphere.api.common.functions.GenericJoiner;
import eu.stratosphere.api.common.typeutils.TypeComparator;
import eu.stratosphere.api.common.typeutils.TypePairComparatorFactory;
import eu.stratosphere.api.common.typeutils.TypeSerializer;
import eu.stratosphere.core.memory.MemorySegment;
import eu.stratosphere.pact.runtime.hash.MutableHashTable;
import eu.stratosphere.pact.runtime.hash.BuildFirstReOpenableHashMatchIterator;
import eu.stratosphere.pact.runtime.hash.BuildSecondReOpenableHashMatchIterator;
import eu.stratosphere.pact.runtime.task.util.JoinTaskIterator;
import eu.stratosphere.pact.runtime.task.util.TaskConfig;
import eu.stratosphere.pact.runtime.util.EmptyMutableObjectIterator;
import eu.stratosphere.util.Collector;
import eu.stratosphere.util.MutableObjectIterator;
public abstract class AbstractCachedBuildSideMatchDriver<IT1, IT2, OT> extends MatchDriver<IT1, IT2, OT> implements ResettablePactDriver<GenericJoiner<IT1, IT2, OT>, OT> {
/**
* We keep it without generic parameters, because they vary depending on which input is the build side.
*/
protected volatile MutableHashTable<?, ?> hashJoin;
private volatile JoinTaskIterator<IT1, IT2, OT> matchIterator;
private final int buildSideIndex;
......@@ -67,23 +61,39 @@ public abstract class AbstractCachedBuildSideMatchDriver<IT1, IT2, OT> extends M
TypePairComparatorFactory<IT1, IT2> pairComparatorFactory =
this.taskContext.getTaskConfig().getPairComparatorFactory(this.taskContext.getUserCodeClassLoader());
int numMemoryPages = this.taskContext.getMemoryManager().computeNumberOfPages(config.getRelativeMemoryDriver());
List<MemorySegment> memSegments = this.taskContext.getMemoryManager().allocatePages(
this.taskContext.getOwningNepheleTask(), numMemoryPages);
double availableMemory = config.getRelativeMemoryDriver();
if (buildSideIndex == 0 && probeSideIndex == 1) {
MutableHashTable<IT1, IT2> hashJoin = new MutableHashTable<IT1, IT2>(serializer1, serializer2, comparator1, comparator2,
pairComparatorFactory.createComparator21(comparator1, comparator2), memSegments, this.taskContext.getIOManager());
this.hashJoin = hashJoin;
hashJoin.open(input1, EmptyMutableObjectIterator.<IT2>get());
matchIterator =
new BuildFirstReOpenableHashMatchIterator<IT1, IT2, OT>(input1, input2,
serializer1, comparator1,
serializer2, comparator2,
pairComparatorFactory.createComparator21(comparator1, comparator2),
this.taskContext.getMemoryManager(),
this.taskContext.getIOManager(),
this.taskContext.getOwningNepheleTask(),
availableMemory
);
} else if (buildSideIndex == 1 && probeSideIndex == 0) {
MutableHashTable<IT2, IT1> hashJoin = new MutableHashTable<IT2, IT1>(serializer2, serializer1, comparator2, comparator1,
pairComparatorFactory.createComparator12(comparator1, comparator2), memSegments, this.taskContext.getIOManager());
this.hashJoin = hashJoin;
hashJoin.open(input2, EmptyMutableObjectIterator.<IT1>get());
matchIterator =
new BuildSecondReOpenableHashMatchIterator<IT1, IT2, OT>(input1, input2,
serializer1, comparator1,
serializer2, comparator2,
pairComparatorFactory.createComparator12(comparator1, comparator2),
this.taskContext.getMemoryManager(),
this.taskContext.getIOManager(),
this.taskContext.getOwningNepheleTask(),
availableMemory
);
} else {
throw new Exception("Error: Inconcistent setup for repeatable hash join driver.");
}
this.matchIterator.open();
}
@Override
......@@ -98,63 +108,17 @@ public abstract class AbstractCachedBuildSideMatchDriver<IT1, IT2, OT> extends M
final Collector<OT> collector = this.taskContext.getOutputCollector();
if (buildSideIndex == 0) {
final TypeSerializer<IT1> buildSideSerializer = taskContext.<IT1> getInputSerializer(0).getSerializer();
final TypeSerializer<IT2> probeSideSerializer = taskContext.<IT2> getInputSerializer(1).getSerializer();
IT1 buildSideRecordFirst;
IT1 buildSideRecordOther;
IT2 probeSideRecord;
IT2 probeSideRecordCopy;
final IT1 buildSideRecordFirstReuse = buildSideSerializer.createInstance();
final IT1 buildSideRecordOtherReuse = buildSideSerializer.createInstance();
final IT2 probeSideRecordReuse = probeSideSerializer.createInstance();
final IT2 probeSideRecordCopyReuse = probeSideSerializer.createInstance();
final BuildFirstReOpenableHashMatchIterator<IT1, IT2, OT> matchIterator = (BuildFirstReOpenableHashMatchIterator<IT1, IT2, OT>) this.matchIterator;
@SuppressWarnings("unchecked")
final MutableHashTable<IT1, IT2> join = (MutableHashTable<IT1, IT2>) this.hashJoin;
while (this.running && matchIterator != null && matchIterator.callWithNextKey(matchStub, collector));
final MutableObjectIterator<IT2> probeSideInput = taskContext.<IT2>getInput(1);
while (this.running && ((probeSideRecord = probeSideInput.next(probeSideRecordReuse)) != null)) {
final MutableHashTable.HashBucketIterator<IT1, IT2> bucket = join.getMatchesFor(probeSideRecord);
if ((buildSideRecordFirst = bucket.next(buildSideRecordFirstReuse)) != null) {
while ((buildSideRecordOther = bucket.next(buildSideRecordOtherReuse)) != null) {
probeSideRecordCopy = probeSideSerializer.copy(probeSideRecord, probeSideRecordCopyReuse);
matchStub.join(buildSideRecordOther, probeSideRecordCopy, collector);
}
matchStub.join(buildSideRecordFirst, probeSideRecord, collector);
}
}
} else if (buildSideIndex == 1) {
final TypeSerializer<IT2> buildSideSerializer = taskContext.<IT2>getInputSerializer(1).getSerializer();
final TypeSerializer<IT1> probeSideSerializer = taskContext.<IT1>getInputSerializer(0).getSerializer();
IT2 buildSideRecordFirst;
IT2 buildSideRecordOther;
IT1 probeSideRecord;
IT1 probeSideRecordCopy;
final IT2 buildSideRecordFirstReuse = buildSideSerializer.createInstance();
final IT2 buildSideRecordOtherReuse = buildSideSerializer.createInstance();
final IT1 probeSideRecordReuse = probeSideSerializer.createInstance();
final IT1 probeSideRecordCopyReuse = probeSideSerializer.createInstance();
@SuppressWarnings("unchecked")
final MutableHashTable<IT2, IT1> join = (MutableHashTable<IT2, IT1>) this.hashJoin;
final BuildSecondReOpenableHashMatchIterator<IT1, IT2, OT> matchIterator = (BuildSecondReOpenableHashMatchIterator<IT1, IT2, OT>) this.matchIterator;
final MutableObjectIterator<IT1> probeSideInput = taskContext.<IT1>getInput(0);
while (this.running && matchIterator != null && matchIterator.callWithNextKey(matchStub, collector));
while (this.running && ((probeSideRecord = probeSideInput.next(probeSideRecordReuse)) != null)) {
final MutableHashTable.HashBucketIterator<IT2, IT1> bucket = join.getMatchesFor(probeSideRecord);
if ((buildSideRecordFirst = bucket.next(buildSideRecordFirstReuse)) != null) {
while ((buildSideRecordOther = bucket.next(buildSideRecordOtherReuse)) != null) {
probeSideRecordCopy = probeSideSerializer.copy(probeSideRecord, probeSideRecordCopyReuse);
matchStub.join(probeSideRecordCopy, buildSideRecordOther, collector);
}
matchStub.join(probeSideRecord, buildSideRecordFirst, collector);
}
}
} else {
throw new Exception();
}
......@@ -164,21 +128,34 @@ public abstract class AbstractCachedBuildSideMatchDriver<IT1, IT2, OT> extends M
public void cleanup() throws Exception {}
@Override
public void reset() throws Exception {}
/**
 * Hands the probe-side input to the re-openable match iterator again.
 * For build side 0 / probe side 1 the second input is re-opened as probe input;
 * in every other configuration the first input is used.
 */
public void reset() throws Exception {
	// both inputs are fetched up front, independent of which one is needed
	final MutableObjectIterator<IT1> firstInput = this.taskContext.<IT1>getInput(0);
	final MutableObjectIterator<IT2> secondInput = this.taskContext.<IT2>getInput(1);

	final boolean buildsOnFirst = (buildSideIndex == 0 && probeSideIndex == 1);
	if (!buildsOnFirst) {
		final BuildSecondReOpenableHashMatchIterator<IT1, IT2, OT> buildSecond =
				(BuildSecondReOpenableHashMatchIterator<IT1, IT2, OT>) this.matchIterator;
		buildSecond.reopenProbe(firstInput);
		return;
	}

	final BuildFirstReOpenableHashMatchIterator<IT1, IT2, OT> buildFirst =
			(BuildFirstReOpenableHashMatchIterator<IT1, IT2, OT>) this.matchIterator;
	buildFirst.reopenProbe(secondInput);
}
@Override
public void teardown() {
MutableHashTable<?, ?> ht = this.hashJoin;
if (ht != null) {
ht.close();
this.running = false;
if (this.matchIterator != null) {
this.matchIterator.close();
}
}
@Override
public void cancel() {
this.running = false;
if (this.hashJoin != null) {
this.hashJoin.close();
if (this.matchIterator != null) {
this.matchIterator.abort();
}
}
}
......@@ -531,6 +531,16 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i
}
catch (Throwable t) {}
}
// if the driver is resettable, invoke its teardown
if (this.driver instanceof ResettablePactDriver) {
final ResettablePactDriver<?, ?> resDriver = (ResettablePactDriver<?, ?>) this.driver;
try {
resDriver.teardown();
} catch (Throwable t) {
throw new Exception("Error while shutting down an iterative operator: " + t.getMessage(), t);
}
}
RegularPactTask.cancelChainedTasks(this.chainedTasks);
......
......@@ -128,7 +128,7 @@ public class AsynchonousPartialSorterITCase {
// merge iterator
LOG.debug("Initializing sortmerger...");
Sorter<Record> sorter = new AsynchronousPartialSorter<Record>(this.memoryManager, source,
this.parentTask, this.serializer, this.comparator, 1.0);
this.parentTask, this.serializer, this.comparator, 0.2);
runPartialSorter(sorter, NUM_RECORDS, 2);
}
......@@ -151,9 +151,9 @@ public class AsynchonousPartialSorterITCase {
// merge iterator
LOG.debug("Initializing sortmerger...");
Sorter<Record> sorter = new AsynchronousPartialSorter<Record>(this.memoryManager, source,
this.parentTask, this.serializer, this.comparator, 1.0);
this.parentTask, this.serializer, this.comparator, 0.15);
runPartialSorter(sorter, NUM_RECORDS, 28);
runPartialSorter(sorter, NUM_RECORDS, 27);
}
catch (Exception t) {
t.printStackTrace();
......
/***********************************************************************************************************************
* Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
**********************************************************************************************************************/
package eu.stratosphere.pact.runtime.task;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.Assert;
import org.junit.Test;
import eu.stratosphere.api.common.functions.GenericJoiner;
import eu.stratosphere.api.java.record.functions.JoinFunction;
import eu.stratosphere.api.java.typeutils.runtime.record.RecordComparator;
import eu.stratosphere.api.java.typeutils.runtime.record.RecordPairComparatorFactory;
import eu.stratosphere.pact.runtime.test.util.DelayingInfinitiveInputIterator;
import eu.stratosphere.pact.runtime.test.util.DriverTestBase;
import eu.stratosphere.pact.runtime.test.util.ExpectedTestException;
import eu.stratosphere.pact.runtime.test.util.NirvanaOutputList;
import eu.stratosphere.pact.runtime.test.util.TaskCancelThread;
import eu.stratosphere.pact.runtime.test.util.UniformRecordGenerator;
import eu.stratosphere.types.IntValue;
import eu.stratosphere.types.Key;
import eu.stratosphere.types.Record;
import eu.stratosphere.util.Collector;
public class CachedMatchTaskTest extends DriverTestBase<GenericJoiner<Record, Record, Record>>
{
// Memory sizes handed to the DriverTestBase constructor: 6 MiB and 3 MiB.
// NOTE(review): presumably the budgets for hash-based and sort-based strategies - confirm in DriverTestBase.
private static final long HASH_MEM = 6*1024*1024;
private static final long SORT_MEM = 3*1024*1024;
// Comparator for the first input: field 0 is the key, read as IntValue.
@SuppressWarnings("unchecked")
private final RecordComparator comparator1 = new RecordComparator(
new int[]{0}, (Class<? extends Key<?>>[])new Class[]{ IntValue.class });
// Comparator for the second input: same key definition as comparator1.
@SuppressWarnings("unchecked")
private final RecordComparator comparator2 = new RecordComparator(
new int[]{0}, (Class<? extends Key<?>>[])new Class[]{ IntValue.class });
// Collects the records emitted by the driver so the tests can check the result size.
private final List<Record> outList = new ArrayList<Record>();
/**
 * Passes HASH_MEM, the literal 2 and SORT_MEM to the DriverTestBase constructor.
 * NOTE(review): the meaning of the middle argument is not visible here - confirm in DriverTestBase.
 */
public CachedMatchTaskTest() {
super(HASH_MEM, 2, SORT_MEM);
}
@Test
public void testHash1MatchTask() {
	// 20 keys x 1 value joined with 10 keys x 2 values, hash table built on the first input
	runResettableMatch(20, 1, 10, 2, true);
}

@Test
public void testHash2MatchTask() {
	// 20 keys x 1 value on both sides, hash table built on the second input
	runResettableMatch(20, 1, 20, 1, false);
}

@Test
public void testHash3MatchTask() {
	// 20 keys x 1 value joined with 20 keys x 20 values, hash table built on the first input
	runResettableMatch(20, 1, 20, 20, true);
}

@Test
public void testHash4MatchTask() {
	// 20 keys x 20 values joined with 20 keys x 1 value, hash table built on the second input
	runResettableMatch(20, 20, 20, 1, false);
}

@Test
public void testHash5MatchTask() {
	// 20 keys x 20 values on both sides, hash table built on the first input
	runResettableMatch(20, 20, 20, 20, true);
}

/**
 * Shared body of the testHashNMatchTask cases: feeds two uniformly generated inputs into a
 * cached match driver via {@code testResettableDriver(..., 3)} and checks the size of the
 * collected output.
 *
 * @param keyCnt1 Number of keys passed to the generator of the first input.
 * @param valCnt1 Number of values per key passed to the generator of the first input.
 * @param keyCnt2 Number of keys passed to the generator of the second input.
 * @param valCnt2 Number of values per key passed to the generator of the second input.
 * @param buildFirst {@code true} to use {@code BuildFirstCachedMatchDriver} with
 *                   {@code HYBRIDHASH_BUILD_FIRST_CACHED}, {@code false} to use
 *                   {@code BuildSecondCachedMatchDriver} with {@code HYBRIDHASH_BUILD_SECOND_CACHED}.
 */
private void runResettableMatch(int keyCnt1, int valCnt1, int keyCnt2, int valCnt2, boolean buildFirst) {
	addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
	addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
	addInputComparator(this.comparator1);
	addInputComparator(this.comparator2);
	getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
	setOutput(this.outList);
	getTaskConfig().setDriverStrategy(buildFirst
			? DriverStrategy.HYBRIDHASH_BUILD_FIRST_CACHED
			: DriverStrategy.HYBRIDHASH_BUILD_SECOND_CACHED);
	getTaskConfig().setRelativeMemoryDriver(1.0f);

	try {
		if (buildFirst) {
			testResettableDriver(new BuildFirstCachedMatchDriver<Record, Record, Record>(), MockMatchStub.class, 3);
		} else {
			testResettableDriver(new BuildSecondCachedMatchDriver<Record, Record, Record>(), MockMatchStub.class, 3);
		}
	} catch (Exception e) {
		e.printStackTrace();
		Assert.fail("Test caused an exception.");
	}

	// every key present on both sides contributes valCnt1 * valCnt2 joined pairs
	final int expCnt = valCnt1*valCnt2*Math.min(keyCnt1, keyCnt2);
	Assert.assertEquals("Wrong result set size.", expCnt, this.outList.size());

	this.outList.clear();
}
@Test
public void testFailingHashFirstMatchTask() {
	runFailingResettableMatch(true);
}

@Test
public void testFailingHashSecondMatchTask() {
	runFailingResettableMatch(false);
}

/**
 * Shared body of the two failing-stub tests: runs a cached match driver with
 * {@code MockFailingMatchStub} over two inputs of 20 keys x 20 values and expects the
 * {@link ExpectedTestException} to reach the test.
 *
 * @param buildFirst {@code true} to use {@code BuildFirstCachedMatchDriver} with
 *                   {@code HYBRIDHASH_BUILD_FIRST_CACHED}, {@code false} to use
 *                   {@code BuildSecondCachedMatchDriver} with {@code HYBRIDHASH_BUILD_SECOND_CACHED}.
 */
private void runFailingResettableMatch(boolean buildFirst) {
	final int keyCnt1 = 20;
	final int valCnt1 = 20;
	final int keyCnt2 = 20;
	final int valCnt2 = 20;

	addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
	addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
	addInputComparator(this.comparator1);
	addInputComparator(this.comparator2);
	getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
	// the output is irrelevant here, only the propagated exception is checked
	setOutput(new NirvanaOutputList());
	getTaskConfig().setDriverStrategy(buildFirst
			? DriverStrategy.HYBRIDHASH_BUILD_FIRST_CACHED
			: DriverStrategy.HYBRIDHASH_BUILD_SECOND_CACHED);
	getTaskConfig().setRelativeMemoryDriver(1.0f);

	try {
		if (buildFirst) {
			testResettableDriver(new BuildFirstCachedMatchDriver<Record, Record, Record>(), MockFailingMatchStub.class, 3);
		} else {
			testResettableDriver(new BuildSecondCachedMatchDriver<Record, Record, Record>(), MockFailingMatchStub.class, 3);
		}
		Assert.fail("Function exception was not forwarded.");
	} catch (ExpectedTestException etex) {
		// good!
	} catch (Exception e) {
		e.printStackTrace();
		Assert.fail("Test caused an exception.");
	}
}
/**
 * Starts a build-first cached match driver in its own thread, cancels it through a
 * {@code TaskCancelThread} and requires that {@code testDriver} returns without an exception.
 */
@Test
public void testCancelHashMatchTaskWhileBuildFirst() {
	final int numKeys = 20;
	final int numValues = 20;

	// input 0 (the build side of this driver) is the delaying iterator
	// NOTE(review): presumably it never ends, so the cancel arrives during the build phase - confirm
	addInput(new DelayingInfinitiveInputIterator(100));
	addInput(new UniformRecordGenerator(numKeys, numValues, false));

	addInputComparator(this.comparator1);
	addInputComparator(this.comparator2);

	getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
	setOutput(new NirvanaOutputList());
	getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST_CACHED);
	getTaskConfig().setRelativeMemoryDriver(1.0f);

	final BuildFirstCachedMatchDriver<Record, Record, Record> driver =
			new BuildFirstCachedMatchDriver<Record, Record, Record>();
	final AtomicBoolean finishedCleanly = new AtomicBoolean(false);

	final Thread runner = new Thread() {
		@Override
		public void run() {
			try {
				// NOTE(review): the sibling build-second test uses MockMatchStub here - confirm the failing stub is intended
				testDriver(driver, MockFailingMatchStub.class);
				finishedCleanly.set(true);
			} catch (Exception ex) {
				ex.printStackTrace();
			}
		}
	};
	runner.start();

	final TaskCancelThread canceler = new TaskCancelThread(1, runner, this);
	canceler.start();

	try {
		canceler.join();
		runner.join();
	} catch (InterruptedException ex) {
		Assert.fail("Joining threads failed");
	}

	Assert.assertTrue("Test threw an exception even though it was properly canceled.", finishedCleanly.get());
}
/**
 * Starts a build-second cached match driver in its own thread, cancels it through a
 * {@code TaskCancelThread} and requires that {@code testDriver} returns without an exception.
 */
@Test
public void testHashCancelMatchTaskWhileBuildSecond() {
	final int numKeys = 20;
	final int numValues = 20;

	addInput(new UniformRecordGenerator(numKeys, numValues, false));
	// input 1 (the build side of this driver) is the delaying iterator
	// NOTE(review): presumably it never ends, so the cancel arrives during the build phase - confirm
	addInput(new DelayingInfinitiveInputIterator(100));

	addInputComparator(this.comparator1);
	addInputComparator(this.comparator2);

	getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
	setOutput(new NirvanaOutputList());
	getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND_CACHED);
	getTaskConfig().setRelativeMemoryDriver(1.0f);

	final BuildSecondCachedMatchDriver<Record, Record, Record> driver =
			new BuildSecondCachedMatchDriver<Record, Record, Record>();
	final AtomicBoolean finishedCleanly = new AtomicBoolean(false);

	final Thread runner = new Thread() {
		@Override
		public void run() {
			try {
				testDriver(driver, MockMatchStub.class);
				finishedCleanly.set(true);
			} catch (Exception ex) {
				ex.printStackTrace();
			}
		}
	};
	runner.start();

	final TaskCancelThread canceler = new TaskCancelThread(1, runner, this);
	canceler.start();

	try {
		canceler.join();
		runner.join();
	} catch (InterruptedException ex) {
		Assert.fail("Joining threads failed");
	}

	Assert.assertTrue("Test threw an exception even though it was properly canceled.", finishedCleanly.get());
}
/**
 * Starts a build-first cached match driver over two finite inputs in its own thread,
 * cancels it through a {@code TaskCancelThread} and requires that {@code testDriver}
 * returns without an exception.
 */
@Test
public void testHashFirstCancelMatchTaskWhileMatching() {
	final int numKeys = 20;
	final int numValues = 20;

	addInput(new UniformRecordGenerator(numKeys, numValues, false));
	addInput(new UniformRecordGenerator(numKeys, numValues, false));

	addInputComparator(this.comparator1);
	addInputComparator(this.comparator2);

	getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
	setOutput(new NirvanaOutputList());
	// NOTE(review): the other tests of the cached drivers configure HYBRIDHASH_BUILD_FIRST_CACHED;
	// confirm whether the non-cached strategy is intended here.
	getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST);
	getTaskConfig().setRelativeMemoryDriver(1.0f);

	final BuildFirstCachedMatchDriver<Record, Record, Record> driver =
			new BuildFirstCachedMatchDriver<Record, Record, Record>();
	final AtomicBoolean finishedCleanly = new AtomicBoolean(false);

	final Thread runner = new Thread() {
		@Override
		public void run() {
			try {
				testDriver(driver, MockMatchStub.class);
				finishedCleanly.set(true);
			} catch (Exception ex) {
				ex.printStackTrace();
			}
		}
	};
	runner.start();

	final TaskCancelThread canceler = new TaskCancelThread(1, runner, this);
	canceler.start();

	try {
		canceler.join();
		runner.join();
	} catch (InterruptedException ex) {
		Assert.fail("Joining threads failed");
	}

	Assert.assertTrue("Test threw an exception even though it was properly canceled.", finishedCleanly.get());
}
/**
 * Runs a {@code BuildSecondCachedMatchDriver} over two finite {@code UniformRecordGenerator}
 * inputs on a background thread and cancels it through a {@code TaskCancelThread}.
 *
 * The test passes only if {@code testDriver} returns on the background thread without
 * throwing.
 */
@Test
public void testHashSecondCancelMatchTaskWhileMatching() {
	final int numKeys = 20;
	final int numValues = 20;

	// both inputs produce the same key/value layout
	addInput(new UniformRecordGenerator(numKeys, numValues, false));
	addInput(new UniformRecordGenerator(numKeys, numValues, false));
	addInputComparator(this.comparator1);
	addInputComparator(this.comparator2);
	getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
	setOutput(new NirvanaOutputList());
	// NOTE(review): the non-cached HYBRIDHASH_BUILD_SECOND strategy is configured although a
	// cached driver is instantiated below - confirm whether this is intentional
	getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND);
	getTaskConfig().setRelativeMemoryDriver(1.0f);

	final BuildSecondCachedMatchDriver<Record, Record, Record> cachedDriver =
			new BuildSecondCachedMatchDriver<Record, Record, Record>();

	// flipped to true only when testDriver(...) returns normally
	final AtomicBoolean completedCleanly = new AtomicBoolean(false);

	final Thread runner = new Thread(new Runnable() {
		@Override
		public void run() {
			try {
				testDriver(cachedDriver, MockMatchStub.class);
				completedCleanly.set(true);
			}
			catch (Exception e) {
				// the failure is reported through the assertion at the end of the test
				e.printStackTrace();
			}
		}
	});
	runner.start();

	final TaskCancelThread canceller = new TaskCancelThread(1, runner, this);
	canceller.start();

	try {
		canceller.join();
		runner.join();
	}
	catch (InterruptedException e) {
		Assert.fail("Joining threads failed");
	}

	Assert.assertTrue("Test threw an exception even though it was properly canceled.", completedCleanly.get());
}
// =================================================================================================
/**
 * Join stub that emits the first record of every pair it is handed and ignores the second.
 */
public static final class MockMatchStub extends JoinFunction {

	private static final long serialVersionUID = 1L;

	/**
	 * Forwards {@code first} to the collector unchanged.
	 *
	 * @param first record that is emitted
	 * @param second record that is ignored
	 * @param collector receives {@code first}
	 */
	@Override
	public void join(Record first, Record second, Collector<Record> collector) throws Exception {
		collector.collect(first);
	}
}
/**
 * Join stub that emits the first record of each pair for its first nine invocations and
 * throws an {@code ExpectedTestException} on the tenth and every later invocation.
 */
public static final class MockFailingMatchStub extends JoinFunction {

	private static final long serialVersionUID = 1L;

	/** Number of times {@code join} has been invoked on this instance. */
	private int cnt = 0;

	/**
	 * Counts the invocation, then either fails or forwards {@code record1}.
	 *
	 * @param record1 record that is emitted while the stub has not started failing
	 * @param record2 record that is ignored
	 * @param out receives {@code record1}
	 */
	@Override
	public void join(Record record1, Record record2, Collector<Record> out) {
		this.cnt++;
		if (this.cnt > 9) {
			// the counter is checked before emitting, so the tenth call emits nothing
			throw new ExpectedTestException();
		}
		out.collect(record1);
	}
}
/**
 * Join stub that sleeps for 100 milliseconds per invocation and never emits a record.
 */
public static final class MockDelayingMatchStub extends JoinFunction {

	private static final long serialVersionUID = 1L;

	/**
	 * Blocks the calling thread for 100 ms; none of the arguments are used.
	 *
	 * @param record1 ignored
	 * @param record2 ignored
	 * @param out ignored, nothing is collected
	 */
	@Override
	public void join(Record record1, Record record2, Collector<Record> out) {
		try {
			Thread.sleep(100);
		}
		catch (InterruptedException ignored) {
			// Swallowed on purpose: the method simply returns early and the interrupt
			// flag is not restored.
			// NOTE(review): presumably so that an interrupt sent during cancellation does
			// not surface as a failure - confirm.
		}
	}
}
}
......@@ -35,6 +35,7 @@ import eu.stratosphere.api.java.typeutils.runtime.record.RecordSerializerFactory
import eu.stratosphere.pact.runtime.sort.UnilateralSortMerger;
import eu.stratosphere.pact.runtime.task.PactDriver;
import eu.stratosphere.pact.runtime.task.PactTaskContext;
import eu.stratosphere.pact.runtime.task.ResettablePactDriver;
import eu.stratosphere.pact.runtime.task.util.TaskConfig;
import eu.stratosphere.types.Record;
import eu.stratosphere.util.Collector;
......@@ -194,16 +195,48 @@ public class DriverTestBase<S extends Function> implements PactTaskContext<S, Re
catch (Throwable t) {}
}
// if the driver is resettable, invoke its teardown
if (this.driver instanceof ResettablePactDriver) {
final ResettablePactDriver<?, ?> resDriver = (ResettablePactDriver<?, ?>) this.driver;
try {
resDriver.teardown();
} catch (Throwable t) {
throw new Exception("Error while shutting down an iterative operator: " + t.getMessage(), t);
}
}
// drop exception, if the task was canceled
if (this.running) {
throw ex;
}
}
finally {
driver.cleanup();
}
}
/**
 * Runs a resettable driver for the given number of iterations.
 *
 * The driver is set up once with this test base as its context. The first iteration calls
 * {@code initialize()}, each later iteration calls {@code reset()}, and every iteration
 * then runs {@code testDriver(driver, stubClass)}. {@code teardown()} is called after the
 * last iteration; if any earlier step throws, the exception propagates and that trailing
 * call is not reached.
 *
 * @param driver the resettable driver under test
 * @param stubClass the stub class handed to {@code testDriver}
 * @param iterations number of rounds; for values below one the driver is only set up and torn down
 * @throws Exception whatever the driver or {@code testDriver} throws
 */
@SuppressWarnings({"unchecked","rawtypes"})
public void testResettableDriver(ResettablePactDriver driver, Class stubClass, int iterations) throws Exception {
	driver.setup(this);

	for (int round = 0; round < iterations; round++) {
		// every round after the first re-uses the driver's state via reset()
		if (round > 0) {
			driver.reset();
		}
		else {
			driver.initialize();
		}
		testDriver(driver, stubClass);
	}

	driver.teardown();
}
public void cancel() throws Exception {
this.running = false;
this.driver.cancel();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册