提交 e5731e0e 编写于 作者: F Fabian Hueske

[FLINK-1073] Enables sorted group input for GroupReduce combiners.

- GroupReduceCombineDriver uses separate comparators for sorting and grouping.
- Adding support for multiple comparators for a single input driver required some refactoring.

This closes #109
上级 4b6d2c57
......@@ -93,9 +93,13 @@ public final class GroupReduceWithCombineProperties extends OperatorDescriptorSi
combinerNode.setDegreeOfParallelism(in.getSource().getDegreeOfParallelism());
SingleInputPlanNode combiner = new SingleInputPlanNode(combinerNode, "Combine("+node.getPactContract()
.getName()+")", toCombiner, DriverStrategy.SORTED_GROUP_COMBINE, this.keyList);
.getName()+")", toCombiner, DriverStrategy.SORTED_GROUP_COMBINE);
combiner.setCosts(new Costs(0, 0));
combiner.initProperties(toCombiner.getGlobalProperties(), toCombiner.getLocalProperties());
// set sorting comparator key info
combiner.setDriverKeyInfo(in.getLocalStrategyKeys(), in.getLocalStrategySortOrder(), 0);
// set grouping comparator key info
combiner.setDriverKeyInfo(this.keyList, 1);
Channel toReducer = new Channel(combiner);
toReducer.setShipStrategy(in.getShipStrategy(), in.getShipStrategyKeys(), in.getShipStrategySortOrder());
......
......@@ -51,8 +51,14 @@ public final class PartialGroupProperties extends OperatorDescriptorSingle {
GroupReduceNode combinerNode = new GroupReduceNode((GroupReduceOperatorBase<?, ?, ?>) node.getPactContract());
combinerNode.setDegreeOfParallelism(in.getSource().getDegreeOfParallelism());
return new SingleInputPlanNode(combinerNode, "Combine("+node.getPactContract().getName()+")", in,
DriverStrategy.SORTED_GROUP_COMBINE, this.keyList);
SingleInputPlanNode combiner = new SingleInputPlanNode(combinerNode, "Combine("+node.getPactContract().getName()+")", in,
DriverStrategy.SORTED_GROUP_COMBINE);
// sorting key info
combiner.setDriverKeyInfo(in.getLocalStrategyKeys(), in.getLocalStrategySortOrder(), 0);
// set grouping comparator key info
combiner.setDriverKeyInfo(this.keyList, 1);
return combiner;
}
@Override
......
......@@ -29,6 +29,7 @@ import java.util.List;
import org.apache.flink.api.common.operators.util.FieldList;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
import org.apache.flink.compiler.CompilerException;
import org.apache.flink.compiler.dag.OptimizerNode;
import org.apache.flink.compiler.dag.SingleInputNode;
import org.apache.flink.runtime.operators.DamBehavior;
......@@ -43,11 +44,11 @@ public class SingleInputPlanNode extends PlanNode {
protected final Channel input;
protected final FieldList keys;
protected final FieldList[] driverKeys;
protected final boolean[] sortOrders;
protected final boolean[][] driverSortOrders;
private TypeComparatorFactory<?> comparator;
private TypeComparatorFactory<?>[] comparators;
public Object postPassHelper;
......@@ -68,8 +69,15 @@ public class SingleInputPlanNode extends PlanNode {
{
super(template, nodeName, driverStrategy);
this.input = input;
this.keys = driverKeyFields;
this.sortOrders = driverSortOrders;
this.comparators = new TypeComparatorFactory<?>[driverStrategy.getNumRequiredComparators()];
this.driverKeys = new FieldList[driverStrategy.getNumRequiredComparators()];
this.driverSortOrders = new boolean[driverStrategy.getNumRequiredComparators()][];
if(driverStrategy.getNumRequiredComparators() > 0) {
this.driverKeys[0] = driverKeyFields;
this.driverSortOrders[0] = driverSortOrders;
}
if (this.input.getShipStrategy() == ShipStrategyType.BROADCAST) {
this.input.setReplicationFactor(getDegreeOfParallelism());
......@@ -81,6 +89,7 @@ public class SingleInputPlanNode extends PlanNode {
} else if (predNode.branchPlan != null) {
this.branchPlan.putAll(predNode.branchPlan);
}
}
// --------------------------------------------------------------------------------------------
......@@ -111,30 +120,71 @@ public class SingleInputPlanNode extends PlanNode {
return this.input.getSource();
}
public FieldList getKeys() {
return this.keys;
/**
* Sets the key field indexes for the specified driver comparator.
*
* @param keys The key field indexes for the specified driver comparator.
* @param id The ID of the driver comparator.
*/
public void setDriverKeyInfo(FieldList keys, int id) {
this.setDriverKeyInfo(keys, getTrueArray(keys.size()), id);
}
/**
* Sets the key field information for the specified driver comparator.
*
* @param keys The key field indexes for the specified driver comparator.
* @param sortOrder The key sort order for the specified driver comparator.
* @param id The ID of the driver comparator.
*/
public void setDriverKeyInfo(FieldList keys, boolean[] sortOrder, int id) {
if(id < 0 || id >= driverKeys.length) {
throw new CompilerException("Invalid id for driver key information. DriverStrategy requires only "
+super.getDriverStrategy().getNumRequiredComparators()+" comparators.");
}
this.driverKeys[id] = keys;
this.driverSortOrders[id] = sortOrder;
}
/**
* Gets the key field indexes for the specified driver comparator.
*
* @param id The id of the driver comparator for which the key field indexes are requested.
* @return The key field indexes of the specified driver comparator.
*/
public FieldList getKeys(int id) {
return this.driverKeys[id];
}
public boolean[] getSortOrders() {
return sortOrders;
/**
* Gets the sort order for the specified driver comparator.
*
* @param id The id of the driver comparator for which the sort order is requested.
* @return The sort order of the specified driver comparator.
*/
public boolean[] getSortOrders(int id) {
return driverSortOrders[id];
}
/**
* Gets the comparator from this PlanNode.
* Gets the specified comparator from this PlanNode.
*
* @param id The ID of the requested comparator.
*
* @return The comparator.
* @return The specified comparator.
*/
public TypeComparatorFactory<?> getComparator() {
return comparator;
public TypeComparatorFactory<?> getComparator(int id) {
return comparators[id];
}
/**
* Sets the comparator for this PlanNode.
* Sets the specified comparator for this PlanNode.
*
* @param comparator The comparator to set.
* @param id The ID of the comparator to set.
*/
public void setComparator(TypeComparatorFactory<?> comparator) {
this.comparator = comparator;
public void setComparator(TypeComparatorFactory<?> comparator, int id) {
this.comparators[id] = comparator;
}
// --------------------------------------------------------------------------------------------
......
......@@ -770,10 +770,9 @@ public class NepheleJobGraphGenerator implements Visitor<PlanNode> {
// set the driver strategy
config.setDriverStrategy(ds);
if (node.getComparator() != null) {
config.setDriverComparator(node.getComparator(), 0);
for(int i=0;i<ds.getNumRequiredComparators();i++) {
config.setDriverComparator(node.getComparator(i), i);
}
// assign memory, file-handles, etc.
assignDriverResources(node, config);
return vertex;
......
......@@ -302,9 +302,9 @@ public abstract class GenericFlatTypePostPass<X, T extends AbstractSchema<X>> im
if (createUtilities) {
// parameterize the node's driver strategy
if (sn.getDriverStrategy().requiresComparator()) {
for(int i=0;i<sn.getDriverStrategy().getNumRequiredComparators();i++) {
try {
sn.setComparator(createComparator(sn.getKeys(), sn.getSortOrders(), schema));
sn.setComparator(createComparator(sn.getKeys(i), sn.getSortOrders(i), schema),i);
} catch (MissingFieldTypeInfoException e) {
throw new CompilerPostPassException("Could not set up runtime strategy for node '" +
optNode.getPactContract().getName() + "'. Missing type information for key field " +
......@@ -371,7 +371,7 @@ public abstract class GenericFlatTypePostPass<X, T extends AbstractSchema<X>> im
// parameterize the node's driver strategy
if (createUtilities) {
if (dn.getDriverStrategy().requiresComparator()) {
if (dn.getDriverStrategy().getNumRequiredComparators() > 0) {
// set the individual comparators
try {
dn.setComparator1(createComparator(dn.getKeysForInput1(), dn.getSortOrders(), schema1));
......
......@@ -161,11 +161,10 @@ public class JavaApiPostPass implements OptimizerPostPass {
SingleInputOperator<?, ?, ?> singleInputOperator = (SingleInputOperator<?, ?, ?>) sn.getOptimizerNode().getPactContract();
// parameterize the node's driver strategy
if (sn.getDriverStrategy().requiresComparator()) {
sn.setComparator(createComparator(singleInputOperator.getOperatorInfo().getInputType(), sn.getKeys(),
getSortOrders(sn.getKeys(), sn.getSortOrders())));
for(int i=0;i<sn.getDriverStrategy().getNumRequiredComparators();i++) {
sn.setComparator(createComparator(singleInputOperator.getOperatorInfo().getInputType(), sn.getKeys(i),
getSortOrders(sn.getKeys(i), sn.getSortOrders(i))), i);
}
// done, we can now propagate our info down
traverseChannel(sn.getInput());
......@@ -184,7 +183,7 @@ public class JavaApiPostPass implements OptimizerPostPass {
DualInputOperator<?, ?, ?, ?> dualInputOperator = (DualInputOperator<?, ?, ?, ?>) dn.getOptimizerNode().getPactContract();
// parameterize the node's driver strategy
if (dn.getDriverStrategy().requiresComparator()) {
if (dn.getDriverStrategy().getNumRequiredComparators() > 0) {
dn.setComparator1(createComparator(dualInputOperator.getOperatorInfo().getFirstInputType(), dn.getKeysForInput1(),
getSortOrders(dn.getKeysForInput1(), dn.getSortOrders())));
dn.setComparator2(createComparator(dualInputOperator.getOperatorInfo().getSecondInputType(), dn.getKeysForInput2(),
......
......@@ -95,7 +95,7 @@ public class GroupOrderTest extends CompilerTestBase {
FieldList local = new FieldList(2, 5);
Assert.assertEquals(ship, c.getShipStrategyKeys());
Assert.assertEquals(local, c.getLocalStrategyKeys());
Assert.assertTrue(c.getLocalStrategySortOrder()[0] == reducer.getSortOrders()[0]);
Assert.assertTrue(c.getLocalStrategySortOrder()[0] == reducer.getSortOrders(0)[0]);
// check that we indeed sort descending
Assert.assertTrue(c.getLocalStrategySortOrder()[1] == groupOrder.getFieldSortDirections()[0]);
......
......@@ -167,7 +167,7 @@ public class GroupReduceCompilationTest extends CompilerTestBase implements java
assertEquals(DriverStrategy.SORTED_GROUP_REDUCE, reduceNode.getDriverStrategy());
// check the keys
assertEquals(new FieldList(1), reduceNode.getKeys());
assertEquals(new FieldList(1), reduceNode.getKeys(0));
assertEquals(new FieldList(1), reduceNode.getInput().getLocalStrategyKeys());
// check DOP
......@@ -222,8 +222,9 @@ public class GroupReduceCompilationTest extends CompilerTestBase implements java
assertEquals(DriverStrategy.SORTED_GROUP_COMBINE, combineNode.getDriverStrategy());
// check the keys
assertEquals(new FieldList(1), reduceNode.getKeys());
assertEquals(new FieldList(1), combineNode.getKeys());
assertEquals(new FieldList(1), reduceNode.getKeys(0));
assertEquals(new FieldList(1), combineNode.getKeys(0));
assertEquals(new FieldList(1), combineNode.getKeys(1));
assertEquals(new FieldList(1), reduceNode.getInput().getLocalStrategyKeys());
// check DOP
......@@ -279,7 +280,7 @@ public class GroupReduceCompilationTest extends CompilerTestBase implements java
assertEquals(DriverStrategy.SORTED_GROUP_REDUCE, reduceNode.getDriverStrategy());
// check the keys
assertEquals(new FieldList(0), reduceNode.getKeys());
assertEquals(new FieldList(0), reduceNode.getKeys(0));
assertEquals(new FieldList(0), reduceNode.getInput().getLocalStrategyKeys());
// check DOP
......@@ -343,8 +344,9 @@ public class GroupReduceCompilationTest extends CompilerTestBase implements java
assertEquals(DriverStrategy.SORTED_GROUP_COMBINE, combineNode.getDriverStrategy());
// check the keys
assertEquals(new FieldList(0), reduceNode.getKeys());
assertEquals(new FieldList(0), combineNode.getKeys());
assertEquals(new FieldList(0), reduceNode.getKeys(0));
assertEquals(new FieldList(0), combineNode.getKeys(0));
assertEquals(new FieldList(0), combineNode.getKeys(1));
assertEquals(new FieldList(0), reduceNode.getInput().getLocalStrategyKeys());
// check DOP
......
......@@ -175,8 +175,8 @@ public class ReduceCompilationTest extends CompilerTestBase implements java.io.S
assertEquals(DriverStrategy.SORTED_PARTIAL_REDUCE, combineNode.getDriverStrategy());
// check the keys
assertEquals(new FieldList(1), reduceNode.getKeys());
assertEquals(new FieldList(1), combineNode.getKeys());
assertEquals(new FieldList(1), reduceNode.getKeys(0));
assertEquals(new FieldList(1), combineNode.getKeys(0));
assertEquals(new FieldList(1), reduceNode.getInput().getLocalStrategyKeys());
// check DOP
......@@ -239,8 +239,8 @@ public class ReduceCompilationTest extends CompilerTestBase implements java.io.S
assertEquals(DriverStrategy.SORTED_PARTIAL_REDUCE, combineNode.getDriverStrategy());
// check the keys
assertEquals(new FieldList(0), reduceNode.getKeys());
assertEquals(new FieldList(0), combineNode.getKeys());
assertEquals(new FieldList(0), reduceNode.getKeys(0));
assertEquals(new FieldList(0), combineNode.getKeys(0));
assertEquals(new FieldList(0), reduceNode.getInput().getLocalStrategyKeys());
// check DOP
......
......@@ -83,7 +83,7 @@ public class WorksetIterationsJavaApiCompilerTest extends CompilerTestBase {
// verify reducer
assertEquals(ShipStrategyType.PARTITION_HASH, worksetReducer.getInput().getShipStrategy());
assertEquals(new FieldList(1, 2), worksetReducer.getKeys());
assertEquals(new FieldList(1, 2), worksetReducer.getKeys(0));
// currently, the system may partition before or after the mapper
ShipStrategyType ss1 = deltaMapper.getInput().getShipStrategy();
......@@ -129,7 +129,7 @@ public class WorksetIterationsJavaApiCompilerTest extends CompilerTestBase {
// verify reducer
assertEquals(ShipStrategyType.PARTITION_HASH, worksetReducer.getInput().getShipStrategy());
assertEquals(new FieldList(1, 2), worksetReducer.getKeys());
assertEquals(new FieldList(1, 2), worksetReducer.getKeys(0));
// verify solution delta
assertEquals(2, joinWithSolutionSetNode.getOutgoingChannels().size());
......@@ -174,7 +174,7 @@ public class WorksetIterationsJavaApiCompilerTest extends CompilerTestBase {
// verify reducer
assertEquals(ShipStrategyType.FORWARD, worksetReducer.getInput().getShipStrategy());
assertEquals(new FieldList(1, 2), worksetReducer.getKeys());
assertEquals(new FieldList(1, 2), worksetReducer.getKeys(0));
// verify solution delta
......
......@@ -97,7 +97,7 @@ public class WorksetIterationsRecordApiCompilerTest extends CompilerTestBase {
// verify reducer
assertEquals(ShipStrategyType.PARTITION_HASH, worksetReducer.getInput().getShipStrategy());
assertEquals(list0, worksetReducer.getKeys());
assertEquals(list0, worksetReducer.getKeys(0));
// currently, the system may partition before or after the mapper
ShipStrategyType ss1 = deltaMapper.getInput().getShipStrategy();
......@@ -142,7 +142,7 @@ public class WorksetIterationsRecordApiCompilerTest extends CompilerTestBase {
// verify reducer
assertEquals(ShipStrategyType.PARTITION_HASH, worksetReducer.getInput().getShipStrategy());
assertEquals(list0, worksetReducer.getKeys());
assertEquals(list0, worksetReducer.getKeys(0));
// verify solution delta
......@@ -186,7 +186,7 @@ public class WorksetIterationsRecordApiCompilerTest extends CompilerTestBase {
// verify reducer
assertEquals(ShipStrategyType.FORWARD, worksetReducer.getInput().getShipStrategy());
assertEquals(list0, worksetReducer.getKeys());
assertEquals(list0, worksetReducer.getKeys(0));
// verify solution delta
......
......@@ -59,8 +59,8 @@ public abstract class AbstractCachedBuildSideMatchDriver<IT1, IT2, OT> extends M
TypeSerializer<IT1> serializer1 = this.taskContext.<IT1>getInputSerializer(0).getSerializer();
TypeSerializer<IT2> serializer2 = this.taskContext.<IT2>getInputSerializer(1).getSerializer();
TypeComparator<IT1> comparator1 = this.taskContext.getInputComparator(0);
TypeComparator<IT2> comparator2 = this.taskContext.getInputComparator(1);
TypeComparator<IT1> comparator1 = this.taskContext.getDriverComparator(0);
TypeComparator<IT2> comparator2 = this.taskContext.getDriverComparator(1);
MutableObjectIterator<IT1> input1 = this.taskContext.getInput(0);
MutableObjectIterator<IT2> input2 = this.taskContext.getInput(1);
......
......@@ -71,8 +71,8 @@ public class AllGroupReduceDriver<IT, OT> implements PactDriver<GroupReduceFunct
}
@Override
public boolean requiresComparatorOnInput() {
return false;
public int getNumberOfDriverComparators() {
return 0;
}
// --------------------------------------------------------------------------------------------
......
......@@ -70,8 +70,8 @@ public class AllReduceDriver<T> implements PactDriver<ReduceFunction<T>, T> {
}
@Override
public boolean requiresComparatorOnInput() {
return false;
public int getNumberOfDriverComparators() {
return 0;
}
// --------------------------------------------------------------------------------------------
......
......@@ -77,8 +77,8 @@ public class CoGroupDriver<IT1, IT2, OT> implements PactDriver<CoGroupFunction<I
@Override
public boolean requiresComparatorOnInput() {
return true;
public int getNumberOfDriverComparators() {
return 2;
}
......@@ -96,8 +96,8 @@ public class CoGroupDriver<IT1, IT2, OT> implements PactDriver<CoGroupFunction<I
// get the key positions and types
final TypeSerializer<IT1> serializer1 = this.taskContext.<IT1>getInputSerializer(0).getSerializer();
final TypeSerializer<IT2> serializer2 = this.taskContext.<IT2>getInputSerializer(1).getSerializer();
final TypeComparator<IT1> groupComparator1 = this.taskContext.getInputComparator(0);
final TypeComparator<IT2> groupComparator2 = this.taskContext.getInputComparator(1);
final TypeComparator<IT1> groupComparator1 = this.taskContext.getDriverComparator(0);
final TypeComparator<IT2> groupComparator2 = this.taskContext.getDriverComparator(1);
final TypePairComparatorFactory<IT1, IT2> pairComparatorFactory = config.getPairComparatorFactory(
this.taskContext.getUserCodeClassLoader());
......
......@@ -69,8 +69,8 @@ public class CoGroupWithSolutionSetFirstDriver<IT1, IT2, OT> implements Resettab
}
@Override
public boolean requiresComparatorOnInput() {
return true;
public int getNumberOfDriverComparators() {
return 1;
}
@Override
......@@ -103,7 +103,7 @@ public class CoGroupWithSolutionSetFirstDriver<IT1, IT2, OT> implements Resettab
TypeComparator<IT1> buildSideComparator = hashTable.getBuildSideComparator().duplicate();
probeSideSerializer = taskContext.<IT2>getInputSerializer(0).getSerializer();
probeSideComparator = taskContext.getInputComparator(0);
probeSideComparator = taskContext.getDriverComparator(0);
solutionSideRecord = buildSideSerializer.createInstance();
......
......@@ -68,8 +68,8 @@ public class CoGroupWithSolutionSetSecondDriver<IT1, IT2, OT> implements Resetta
}
@Override
public boolean requiresComparatorOnInput() {
return true;
public int getNumberOfDriverComparators() {
return 1;
}
@Override
......@@ -102,7 +102,7 @@ public class CoGroupWithSolutionSetSecondDriver<IT1, IT2, OT> implements Resetta
TypeComparator<IT2> buildSideComparator = hashTable.getBuildSideComparator().duplicate();
probeSideSerializer = taskContext.<IT1>getInputSerializer(0).getSerializer();
probeSideComparator = taskContext.getInputComparator(0);
probeSideComparator = taskContext.getDriverComparator(0);
solutionSideRecord = buildSideSerializer.createInstance();
......
......@@ -63,8 +63,8 @@ public class CollectorMapDriver<IT, OT> implements PactDriver<GenericCollectorMa
}
@Override
public boolean requiresComparatorOnInput() {
return false;
public int getNumberOfDriverComparators() {
return 0;
}
@Override
......
......@@ -88,8 +88,8 @@ public class CrossDriver<T1, T2, OT> implements PactDriver<CrossFunction<T1, T2,
@Override
public boolean requiresComparatorOnInput() {
return false;
public int getNumberOfDriverComparators() {
return 0;
}
......
......@@ -34,66 +34,66 @@ import org.apache.flink.runtime.operators.chaining.SynchronousChainedCombineDriv
*/
public enum DriverStrategy {
// no local strategy, as for sources and sinks
NONE(null, null, PIPELINED, false),
NONE(null, null, PIPELINED, 0),
// a unary no-op operator
UNARY_NO_OP(NoOpDriver.class, null, PIPELINED, PIPELINED, false),
UNARY_NO_OP(NoOpDriver.class, null, PIPELINED, PIPELINED, 0),
// a binary no-op operator. non implementation available
BINARY_NO_OP(null, null, PIPELINED, PIPELINED, false),
BINARY_NO_OP(null, null, PIPELINED, PIPELINED, 0),
// the old mapper
COLLECTOR_MAP(CollectorMapDriver.class, ChainedCollectorMapDriver.class, PIPELINED, false),
COLLECTOR_MAP(CollectorMapDriver.class, ChainedCollectorMapDriver.class, PIPELINED, 0),
// the proper mapper
MAP(MapDriver.class, ChainedMapDriver.class, PIPELINED, false),
MAP(MapDriver.class, ChainedMapDriver.class, PIPELINED, 0),
// the proper map partition
MAP_PARTITION(MapPartitionDriver.class, null, PIPELINED, false),
MAP_PARTITION(MapPartitionDriver.class, null, PIPELINED, 0),
// the flat mapper
FLAT_MAP(FlatMapDriver.class, ChainedFlatMapDriver.class, PIPELINED, false),
FLAT_MAP(FlatMapDriver.class, ChainedFlatMapDriver.class, PIPELINED, 0),
// group everything together into one group and apply the Reduce function
ALL_REDUCE(AllReduceDriver.class, null, PIPELINED, false),
ALL_REDUCE(AllReduceDriver.class, null, PIPELINED, 0),
// group everything together into one group and apply the GroupReduce function
ALL_GROUP_REDUCE(AllGroupReduceDriver.class, null, PIPELINED, false),
ALL_GROUP_REDUCE(AllGroupReduceDriver.class, null, PIPELINED, 0),
// group everything together into one group and apply the GroupReduce's combine function
ALL_GROUP_COMBINE(AllGroupReduceDriver.class, null, PIPELINED, false),
ALL_GROUP_COMBINE(AllGroupReduceDriver.class, null, PIPELINED, 0),
// grouping the inputs and apply the Reduce Function
SORTED_REDUCE(ReduceDriver.class, null, PIPELINED, true),
SORTED_REDUCE(ReduceDriver.class, null, PIPELINED, 1),
// sorted partial reduce is the combiner for the Reduce. same function, but potentially not fully sorted
SORTED_PARTIAL_REDUCE(ReduceCombineDriver.class, null, MATERIALIZING, true),
SORTED_PARTIAL_REDUCE(ReduceCombineDriver.class, null, MATERIALIZING, 1),
// grouping the inputs and apply the GroupReduce function
SORTED_GROUP_REDUCE(GroupReduceDriver.class, null, PIPELINED, true),
SORTED_GROUP_REDUCE(GroupReduceDriver.class, null, PIPELINED, 1),
// partially grouping inputs (best effort resulting possibly in duplicates --> combiner)
SORTED_GROUP_COMBINE(GroupReduceCombineDriver.class, SynchronousChainedCombineDriver.class, MATERIALIZING, true),
SORTED_GROUP_COMBINE(GroupReduceCombineDriver.class, SynchronousChainedCombineDriver.class, MATERIALIZING, 2),
// both inputs are merged, but materialized to the side for block-nested-loop-join among values with equal key
MERGE(MatchDriver.class, null, MATERIALIZING, MATERIALIZING, true),
MERGE(MatchDriver.class, null, MATERIALIZING, MATERIALIZING, 2),
// co-grouping inputs
CO_GROUP(CoGroupDriver.class, null, PIPELINED, PIPELINED, true),
CO_GROUP(CoGroupDriver.class, null, PIPELINED, PIPELINED, 2),
// the first input is build side, the second side is probe side of a hybrid hash table
HYBRIDHASH_BUILD_FIRST(MatchDriver.class, null, FULL_DAM, MATERIALIZING, true),
HYBRIDHASH_BUILD_FIRST(MatchDriver.class, null, FULL_DAM, MATERIALIZING, 2),
// the second input is build side, the first side is probe side of a hybrid hash table
HYBRIDHASH_BUILD_SECOND(MatchDriver.class, null, MATERIALIZING, FULL_DAM, true),
HYBRIDHASH_BUILD_SECOND(MatchDriver.class, null, MATERIALIZING, FULL_DAM, 2),
// a cached variant of HYBRIDHASH_BUILD_FIRST, that can only be used inside of iterations
HYBRIDHASH_BUILD_FIRST_CACHED(BuildFirstCachedMatchDriver.class, null, FULL_DAM, MATERIALIZING, true),
HYBRIDHASH_BUILD_FIRST_CACHED(BuildFirstCachedMatchDriver.class, null, FULL_DAM, MATERIALIZING, 2),
// cached variant of HYBRIDHASH_BUILD_SECOND, that can only be used inside of iterations
HYBRIDHASH_BUILD_SECOND_CACHED(BuildSecondCachedMatchDriver.class, null, MATERIALIZING, FULL_DAM, true),
HYBRIDHASH_BUILD_SECOND_CACHED(BuildSecondCachedMatchDriver.class, null, MATERIALIZING, FULL_DAM, 2),
// the second input is inner loop, the first input is outer loop and block-wise processed
NESTEDLOOP_BLOCKED_OUTER_FIRST(CrossDriver.class, null, MATERIALIZING, FULL_DAM, false),
NESTEDLOOP_BLOCKED_OUTER_FIRST(CrossDriver.class, null, MATERIALIZING, FULL_DAM, 0),
// the first input is inner loop, the second input is outer loop and block-wise processed
NESTEDLOOP_BLOCKED_OUTER_SECOND(CrossDriver.class, null, FULL_DAM, MATERIALIZING, false),
NESTEDLOOP_BLOCKED_OUTER_SECOND(CrossDriver.class, null, FULL_DAM, MATERIALIZING, 0),
// the second input is inner loop, the first input is outer loop and stream-processed
NESTEDLOOP_STREAMED_OUTER_FIRST(CrossDriver.class, null, PIPELINED, FULL_DAM, false),
NESTEDLOOP_STREAMED_OUTER_FIRST(CrossDriver.class, null, PIPELINED, FULL_DAM, 0),
// the first input is inner loop, the second input is outer loop and stream-processed
NESTEDLOOP_STREAMED_OUTER_SECOND(CrossDriver.class, null, FULL_DAM, PIPELINED, false),
NESTEDLOOP_STREAMED_OUTER_SECOND(CrossDriver.class, null, FULL_DAM, PIPELINED, 0),
// union utility op. unions happen implicitly on the network layer (in the readers) when bundeling streams
UNION(null, null, FULL_DAM, FULL_DAM, false);
UNION(null, null, FULL_DAM, FULL_DAM, 0);
// explicit binary union between a streamed and a cached input
// UNION_WITH_CACHED(UnionWithTempOperator.class, null, FULL_DAM, PIPELINED, false);
......@@ -108,35 +108,35 @@ public enum DriverStrategy {
private final int numInputs;
private final boolean requiresComparator;
private final int numRequiredComparators;
@SuppressWarnings("unchecked")
private DriverStrategy(
@SuppressWarnings("rawtypes") Class<? extends PactDriver> driverClass,
@SuppressWarnings("rawtypes") Class<? extends ChainedDriver> pushChainDriverClass,
DamBehavior dam, boolean comparator)
DamBehavior dam, int numComparator)
{
this.driverClass = (Class<? extends PactDriver<?, ?>>) driverClass;
this.pushChainDriver = (Class<? extends ChainedDriver<?, ?>>) pushChainDriverClass;
this.numInputs = 1;
this.dam1 = dam;
this.dam2 = null;
this.requiresComparator = comparator;
this.numRequiredComparators = numComparator;
}
@SuppressWarnings("unchecked")
private DriverStrategy(
@SuppressWarnings("rawtypes") Class<? extends PactDriver> driverClass,
@SuppressWarnings("rawtypes") Class<? extends ChainedDriver> pushChainDriverClass,
DamBehavior firstDam, DamBehavior secondDam, boolean comparator)
DamBehavior firstDam, DamBehavior secondDam, int numComparator)
{
this.driverClass = (Class<? extends PactDriver<?, ?>>) driverClass;
this.pushChainDriver = (Class<? extends ChainedDriver<?, ?>>) pushChainDriverClass;
this.numInputs = 2;
this.dam1 = firstDam;
this.dam2 = secondDam;
this.requiresComparator = comparator;
this.numRequiredComparators = numComparator;
}
// --------------------------------------------------------------------------------------------
......@@ -180,7 +180,7 @@ public enum DriverStrategy {
return this.dam1.isMaterializing() || (this.dam2 != null && this.dam2.isMaterializing());
}
public boolean requiresComparator() {
return this.requiresComparator;
public int getNumRequiredComparators() {
return this.numRequiredComparators;
}
}
......@@ -62,8 +62,8 @@ public class FlatMapDriver<IT, OT> implements PactDriver<FlatMapFunction<IT, OT>
}
@Override
public boolean requiresComparatorOnInput() {
return false;
public int getNumberOfDriverComparators() {
return 0;
}
@Override
......
......@@ -60,7 +60,9 @@ public class GroupReduceCombineDriver<T> implements PactDriver<FlatCombineFuncti
private TypeSerializer<T> serializer;
private TypeComparator<T> comparator;
private TypeComparator<T> sortingComparator;
private TypeComparator<T> groupingComparator;
private QuickSort sortAlgo = new QuickSort();
......@@ -82,7 +84,7 @@ public class GroupReduceCombineDriver<T> implements PactDriver<FlatCombineFuncti
public int getNumberOfInputs() {
return 1;
}
@Override
public Class<FlatCombineFunction<T>> getStubType() {
@SuppressWarnings("unchecked")
......@@ -91,8 +93,8 @@ public class GroupReduceCombineDriver<T> implements PactDriver<FlatCombineFuncti
}
@Override
public boolean requiresComparatorOnInput() {
return true;
public int getNumberOfDriverComparators() {
return 2;
}
@Override
......@@ -107,7 +109,8 @@ public class GroupReduceCombineDriver<T> implements PactDriver<FlatCombineFuncti
final TypeSerializerFactory<T> serializerFactory = this.taskContext.getInputSerializer(0);
this.serializer = serializerFactory.getSerializer();
this.comparator = this.taskContext.getInputComparator(0);
this.sortingComparator = this.taskContext.getDriverComparator(0);
this.groupingComparator = this.taskContext.getDriverComparator(1);
this.combiner = this.taskContext.getStub();
this.output = this.taskContext.getOutputCollector();
......@@ -115,12 +118,12 @@ public class GroupReduceCombineDriver<T> implements PactDriver<FlatCombineFuncti
numMemoryPages);
// instantiate a fix-length in-place sorter, if possible, otherwise the out-of-place sorter
if (this.comparator.supportsSerializationWithKeyNormalization() &&
if (this.sortingComparator.supportsSerializationWithKeyNormalization() &&
this.serializer.getLength() > 0 && this.serializer.getLength() <= THRESHOLD_FOR_IN_PLACE_SORTING)
{
this.sorter = new FixedLengthRecordSorter<T>(this.serializer, this.comparator, memory);
this.sorter = new FixedLengthRecordSorter<T>(this.serializer, this.sortingComparator, memory);
} else {
this.sorter = new NormalizedKeySorter<T>(this.serializer, this.comparator.duplicate(), memory);
this.sorter = new NormalizedKeySorter<T>(this.serializer, this.sortingComparator.duplicate(), memory);
}
}
......@@ -163,7 +166,7 @@ public class GroupReduceCombineDriver<T> implements PactDriver<FlatCombineFuncti
this.sortAlgo.sort(sorter);
final KeyGroupedIterator<T> keyIter = new KeyGroupedIterator<T>(sorter.getIterator(), this.serializer,
this.comparator);
this.groupingComparator);
final FlatCombineFunction<T> combiner = this.combiner;
final Collector<T> output = this.output;
......@@ -177,12 +180,16 @@ public class GroupReduceCombineDriver<T> implements PactDriver<FlatCombineFuncti
@Override
public void cleanup() throws Exception {
this.memManager.release(this.sorter.dispose());
if(this.sorter != null) {
this.memManager.release(this.sorter.dispose());
}
}
@Override
public void cancel() {
this.running = false;
this.memManager.release(this.sorter.dispose());
if(this.sorter != null) {
this.memManager.release(this.sorter.dispose());
}
}
}
......@@ -74,8 +74,8 @@ public class GroupReduceDriver<IT, OT> implements PactDriver<GroupReduceFunction
}
@Override
public boolean requiresComparatorOnInput() {
return true;
public int getNumberOfDriverComparators() {
return 1;
}
// --------------------------------------------------------------------------------------------
......@@ -87,7 +87,7 @@ public class GroupReduceDriver<IT, OT> implements PactDriver<GroupReduceFunction
throw new Exception("Unrecognized driver strategy for GroupReduce driver: " + config.getDriverStrategy().name());
}
this.serializer = this.taskContext.<IT>getInputSerializer(0).getSerializer();
this.comparator = this.taskContext.getInputComparator(0);
this.comparator = this.taskContext.getDriverComparator(0);
this.input = this.taskContext.getInput(0);
}
......
......@@ -68,8 +68,8 @@ public class JoinWithSolutionSetFirstDriver<IT1, IT2, OT> implements ResettableP
}
@Override
public boolean requiresComparatorOnInput() {
return true;
public int getNumberOfDriverComparators() {
return 1;
}
@Override
......
......@@ -68,8 +68,8 @@ public class JoinWithSolutionSetSecondDriver<IT1, IT2, OT> implements Resettable
}
@Override
public boolean requiresComparatorOnInput() {
return true;
public int getNumberOfDriverComparators() {
return 1;
}
@Override
......
......@@ -62,8 +62,8 @@ public class MapDriver<IT, OT> implements PactDriver<MapFunction<IT, OT>, OT> {
}
@Override
public boolean requiresComparatorOnInput() {
return false;
public int getNumberOfDriverComparators() {
return 0;
}
@Override
......
......@@ -59,8 +59,8 @@ public class MapPartitionDriver<IT, OT> implements PactDriver<MapPartitionFuncti
}
@Override
public boolean requiresComparatorOnInput() {
return false;
public int getNumberOfDriverComparators() {
return 0;
}
@Override
......
......@@ -75,8 +75,8 @@ public class MatchDriver<IT1, IT2, OT> implements PactDriver<FlatJoinFunction<IT
}
@Override
public boolean requiresComparatorOnInput() {
return true;
public int getNumberOfDriverComparators() {
return 2;
}
@Override
......@@ -100,8 +100,8 @@ public class MatchDriver<IT1, IT2, OT> implements PactDriver<FlatJoinFunction<IT
// get the key positions and types
final TypeSerializer<IT1> serializer1 = this.taskContext.<IT1>getInputSerializer(0).getSerializer();
final TypeSerializer<IT2> serializer2 = this.taskContext.<IT2>getInputSerializer(1).getSerializer();
final TypeComparator<IT1> comparator1 = this.taskContext.getInputComparator(0);
final TypeComparator<IT2> comparator2 = this.taskContext.getInputComparator(1);
final TypeComparator<IT1> comparator1 = this.taskContext.getDriverComparator(0);
final TypeComparator<IT2> comparator2 = this.taskContext.getDriverComparator(1);
final TypePairComparatorFactory<IT1, IT2> pairComparatorFactory = config.getPairComparatorFactory(
this.taskContext.getUserCodeClassLoader());
......
......@@ -52,8 +52,8 @@ public class NoOpDriver<T> implements PactDriver<AbstractRichFunction, T> {
}
@Override
public boolean requiresComparatorOnInput() {
return false;
public int getNumberOfDriverComparators() {
return 0;
}
@Override
......
......@@ -44,19 +44,19 @@ public interface PactDriver<S extends Function, OT> {
int getNumberOfInputs();
/**
* Gets the class of the stub type that is run by this task. For example, a <tt>MapTask</tt> should return
* <code>MapFunction.class</code>.
* Gets the number of comparators required for this driver.
*
* @return The class of the stub type run by the task.
* @return The number of comparators required for this driver.
*/
Class<S> getStubType();
int getNumberOfDriverComparators();
/**
* Flag indicating whether the inputs require always comparators or not.
* Gets the class of the stub type that is run by this task. For example, a <tt>MapTask</tt> should return
* <code>MapFunction.class</code>.
*
* @return True, if the initialization should look for and create comparators, false otherwise.
* @return The class of the stub type run by the task.
*/
boolean requiresComparatorOnInput();
Class<S> getStubType();
/**
* This method is called before the user code is opened. An exception thrown by this method
......
......@@ -54,7 +54,7 @@ public interface PactTaskContext<S, OT> {
<X> TypeSerializerFactory<X> getInputSerializer(int index);
<X> TypeComparator<X> getInputComparator(int index);
<X> TypeComparator<X> getDriverComparator(int index);
S getStub();
......
......@@ -93,8 +93,8 @@ public class ReduceCombineDriver<T> implements PactDriver<ReduceFunction<T>, T>
}
@Override
public boolean requiresComparatorOnInput() {
return true;
public int getNumberOfDriverComparators() {
return 1;
}
@Override
......@@ -109,7 +109,7 @@ public class ReduceCombineDriver<T> implements PactDriver<ReduceFunction<T>, T>
// instantiate the serializer / comparator
final TypeSerializerFactory<T> serializerFactory = this.taskContext.getInputSerializer(0);
this.comparator = this.taskContext.getInputComparator(0);
this.comparator = this.taskContext.getDriverComparator(0);
this.serializer = serializerFactory.getSerializer();
this.reducer = this.taskContext.getStub();
this.output = this.taskContext.getOutputCollector();
......
......@@ -73,8 +73,8 @@ public class ReduceDriver<T> implements PactDriver<ReduceFunction<T>, T> {
}
@Override
public boolean requiresComparatorOnInput() {
return true;
public int getNumberOfDriverComparators() {
return 1;
}
// --------------------------------------------------------------------------------------------
......@@ -86,7 +86,7 @@ public class ReduceDriver<T> implements PactDriver<ReduceFunction<T>, T> {
throw new Exception("Unrecognized driver strategy for Reduce driver: " + config.getDriverStrategy().name());
}
this.serializer = this.taskContext.<T>getInputSerializer(0).getSerializer();
this.comparator = this.taskContext.getInputComparator(0);
this.comparator = this.taskContext.getDriverComparator(0);
this.input = this.taskContext.getInput(0);
}
......
......@@ -299,9 +299,10 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i
// the local processing includes building the dams / caches
try {
int numInputs = driver.getNumberOfInputs();
int numComparators = driver.getNumberOfDriverComparators();
int numBroadcastInputs = this.config.getNumBroadcastInputs();
initInputsSerializersAndComparators(numInputs);
initInputsSerializersAndComparators(numInputs, numComparators);
initBroadcastInputsSerializers(numBroadcastInputs);
// set the iterative status for inputs and broadcast inputs
......@@ -781,23 +782,27 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i
/**
* Creates all the serializers and comparators.
*/
protected void initInputsSerializersAndComparators(int numInputs) throws Exception {
protected void initInputsSerializersAndComparators(int numInputs, int numComparators) throws Exception {
this.inputSerializers = new TypeSerializerFactory<?>[numInputs];
this.inputComparators = this.driver.requiresComparatorOnInput() ? new TypeComparator[numInputs] : null;
this.inputComparators = numComparators > 0 ? new TypeComparator[numComparators] : null;
this.inputIterators = new MutableObjectIterator[numInputs];
// ---------------- create the input serializers ---------------------
for (int i = 0; i < numInputs; i++) {
// ---------------- create the serializer first ---------------------
final TypeSerializerFactory<?> serializerFactory = this.config.getInputSerializer(i, this.userCodeClassLoader);
this.inputSerializers[i] = serializerFactory;
// ---------------- create the driver's comparator ---------------------
this.inputIterators[i] = createInputIterator(this.inputReaders[i], this.inputSerializers[i]);
}
// ---------------- create the driver's comparators ---------------------
for (int i = 0; i < numComparators; i++) {
if (this.inputComparators != null) {
final TypeComparatorFactory<?> comparatorFactory = this.config.getDriverComparator(i, this.userCodeClassLoader);
this.inputComparators[i] = comparatorFactory.createComparator();
}
this.inputIterators[i] = createInputIterator(this.inputReaders[i], this.inputSerializers[i]);
}
}
......@@ -1157,11 +1162,11 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i
@Override
public <X> TypeComparator<X> getInputComparator(int index) {
public <X> TypeComparator<X> getDriverComparator(int index) {
if (this.inputComparators == null) {
throw new IllegalStateException("Comparators have not been created!");
}
else if (index < 0 || index >= this.driver.getNumberOfInputs()) {
else if (index < 0 || index >= this.driver.getNumberOfDriverComparators()) {
throw new IndexOutOfBoundsException();
}
......
......@@ -56,7 +56,9 @@ public class SynchronousChainedCombineDriver<T> extends ChainedDriver<T, T> {
private TypeSerializer<T> serializer;
private TypeComparator<T> comparator;
private TypeComparator<T> sortingComparator;
private TypeComparator<T> groupingComparator;
private AbstractInvokable parent;
......@@ -92,19 +94,21 @@ public class SynchronousChainedCombineDriver<T> extends ChainedDriver<T, T> {
// instantiate the serializer / comparator
final TypeSerializerFactory<T> serializerFactory = this.config.getInputSerializer(0, this.userCodeClassLoader);
final TypeComparatorFactory<T> comparatorFactory = this.config.getDriverComparator(0, this.userCodeClassLoader);
final TypeComparatorFactory<T> sortingComparatorFactory = this.config.getDriverComparator(0, this.userCodeClassLoader);
final TypeComparatorFactory<T> groupingComparatorFactory = this.config.getDriverComparator(1, this.userCodeClassLoader);
this.serializer = serializerFactory.getSerializer();
this.comparator = comparatorFactory.createComparator();
this.sortingComparator = sortingComparatorFactory.createComparator();
this.groupingComparator = groupingComparatorFactory.createComparator();
final List<MemorySegment> memory = this.memManager.allocatePages(this.parent, numMemoryPages);
// instantiate a fix-length in-place sorter, if possible, otherwise the out-of-place sorter
if (this.comparator.supportsSerializationWithKeyNormalization() &&
if (this.sortingComparator.supportsSerializationWithKeyNormalization() &&
this.serializer.getLength() > 0 && this.serializer.getLength() <= THRESHOLD_FOR_IN_PLACE_SORTING)
{
this.sorter = new FixedLengthRecordSorter<T>(this.serializer, this.comparator, memory);
this.sorter = new FixedLengthRecordSorter<T>(this.serializer, this.sortingComparator, memory);
} else {
this.sorter = new NormalizedKeySorter<T>(this.serializer, this.comparator.duplicate(), memory);
this.sorter = new NormalizedKeySorter<T>(this.serializer, this.sortingComparator.duplicate(), memory);
}
}
......@@ -183,7 +187,7 @@ public class SynchronousChainedCombineDriver<T> extends ChainedDriver<T, T> {
this.sortAlgo.sort(sorter);
// run the combiner
final KeyGroupedIterator<T> keyIter = new KeyGroupedIterator<T>(sorter.getIterator(), this.serializer,
this.comparator);
this.groupingComparator);
// cache references on the stack
final FlatCombineFunction<T> stub = this.combiner;
......
......@@ -70,8 +70,8 @@ public class CachedMatchTaskTest extends DriverTestBase<FlatJoinFunction<Record,
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST_CACHED);
......@@ -101,8 +101,8 @@ public class CachedMatchTaskTest extends DriverTestBase<FlatJoinFunction<Record,
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND_CACHED);
......@@ -132,8 +132,8 @@ public class CachedMatchTaskTest extends DriverTestBase<FlatJoinFunction<Record,
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST_CACHED);
......@@ -163,8 +163,8 @@ public class CachedMatchTaskTest extends DriverTestBase<FlatJoinFunction<Record,
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND_CACHED);
......@@ -194,8 +194,8 @@ public class CachedMatchTaskTest extends DriverTestBase<FlatJoinFunction<Record,
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST_CACHED);
......@@ -225,8 +225,8 @@ public class CachedMatchTaskTest extends DriverTestBase<FlatJoinFunction<Record,
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(new NirvanaOutputList());
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST_CACHED);
......@@ -255,8 +255,8 @@ public class CachedMatchTaskTest extends DriverTestBase<FlatJoinFunction<Record,
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(new NirvanaOutputList());
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND_CACHED);
......@@ -283,8 +283,8 @@ public class CachedMatchTaskTest extends DriverTestBase<FlatJoinFunction<Record,
addInput(new DelayingInfinitiveInputIterator(100));
addInput(new UniformRecordGenerator(keyCnt, valCnt, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
......@@ -330,8 +330,8 @@ public class CachedMatchTaskTest extends DriverTestBase<FlatJoinFunction<Record,
addInput(new UniformRecordGenerator(keyCnt, valCnt, false));
addInput(new DelayingInfinitiveInputIterator(100));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(new NirvanaOutputList());
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND_CACHED);
......@@ -374,8 +374,8 @@ public class CachedMatchTaskTest extends DriverTestBase<FlatJoinFunction<Record,
addInput(new UniformRecordGenerator(keyCnt, valCnt, false));
addInput(new UniformRecordGenerator(keyCnt, valCnt, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(new NirvanaOutputList());
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST);
......@@ -418,8 +418,8 @@ public class CachedMatchTaskTest extends DriverTestBase<FlatJoinFunction<Record,
addInput(new UniformRecordGenerator(keyCnt, valCnt, false));
addInput(new UniformRecordGenerator(keyCnt, valCnt, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(new NirvanaOutputList());
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND);
......
......@@ -62,8 +62,8 @@ public class CoGroupTaskExternalITCase extends DriverTestBase<CoGroupFunction<Re
(keyCnt1 > keyCnt2 ? (keyCnt1 - keyCnt2) * valCnt1 : (keyCnt2 - keyCnt1) * valCnt2);
setOutput(this.output);
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
......
......@@ -68,8 +68,8 @@ public class CoGroupTaskTest extends DriverTestBase<CoGroupFunction<Record, Reco
(keyCnt1 > keyCnt2 ? (keyCnt1 - keyCnt2) * valCnt1 : (keyCnt2 - keyCnt1) * valCnt2);
setOutput(this.output);
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
......@@ -99,8 +99,8 @@ public class CoGroupTaskTest extends DriverTestBase<CoGroupFunction<Record, Reco
(keyCnt1 > keyCnt2 ? (keyCnt1 - keyCnt2) * valCnt1 : (keyCnt2 - keyCnt1) * valCnt2);
setOutput(this.output);
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
......@@ -130,8 +130,8 @@ public class CoGroupTaskTest extends DriverTestBase<CoGroupFunction<Record, Reco
(keyCnt1 > keyCnt2 ? (keyCnt1 - keyCnt2) * valCnt1 : (keyCnt2 - keyCnt1) * valCnt2);
setOutput(this.output);
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
......@@ -161,8 +161,8 @@ public class CoGroupTaskTest extends DriverTestBase<CoGroupFunction<Record, Reco
(keyCnt1 > keyCnt2 ? (keyCnt1 - keyCnt2) * valCnt1 : (keyCnt2 - keyCnt1) * valCnt2);
setOutput(this.output);
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
......@@ -195,8 +195,8 @@ public class CoGroupTaskTest extends DriverTestBase<CoGroupFunction<Record, Reco
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, true));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, true));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
......@@ -225,8 +225,8 @@ public class CoGroupTaskTest extends DriverTestBase<CoGroupFunction<Record, Reco
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, true));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, true));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
......@@ -251,8 +251,8 @@ public class CoGroupTaskTest extends DriverTestBase<CoGroupFunction<Record, Reco
setOutput(this.output);
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
......@@ -302,8 +302,8 @@ public class CoGroupTaskTest extends DriverTestBase<CoGroupFunction<Record, Reco
setOutput(this.output);
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
......@@ -354,8 +354,8 @@ public class CoGroupTaskTest extends DriverTestBase<CoGroupFunction<Record, Reco
setOutput(this.output);
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
......
......@@ -59,7 +59,8 @@ public class CombineTaskExternalITCase extends DriverTestBase<RichGroupReduceFun
final int valCnt = 8;
addInput(new UniformRecordGenerator(keyCnt, valCnt, false));
addInputComparator(this.comparator);
addDriverComparator(this.comparator);
addDriverComparator(this.comparator);
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
......@@ -112,7 +113,8 @@ public class CombineTaskExternalITCase extends DriverTestBase<RichGroupReduceFun
final int valCnt = 8;
addInput(new UniformRecordGenerator(keyCnt, valCnt, false));
addInputComparator(this.comparator);
addDriverComparator(this.comparator);
addDriverComparator(this.comparator);
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
......
......@@ -61,7 +61,8 @@ public class CombineTaskTest extends DriverTestBase<RichGroupReduceFunction<Reco
int valCnt = 20;
addInput(new UniformRecordGenerator(keyCnt, valCnt, false));
addInputComparator(this.comparator);
addDriverComparator(this.comparator);
addDriverComparator(this.comparator);
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
......@@ -97,7 +98,8 @@ public class CombineTaskTest extends DriverTestBase<RichGroupReduceFunction<Reco
int valCnt = 20;
addInput(new UniformRecordGenerator(keyCnt, valCnt, false));
addInputComparator(this.comparator);
addDriverComparator(this.comparator);
addDriverComparator(this.comparator);
setOutput(new DiscardingOutputCollector<Record>());
getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
......@@ -121,7 +123,8 @@ public class CombineTaskTest extends DriverTestBase<RichGroupReduceFunction<Reco
public void testCancelCombineTaskSorting()
{
addInput(new DelayingInfinitiveInputIterator(100));
addInputComparator(this.comparator);
addDriverComparator(this.comparator);
addDriverComparator(this.comparator);
setOutput(new DiscardingOutputCollector<Record>());
getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
......
......@@ -71,8 +71,8 @@ public class MatchTaskExternalITCase extends DriverTestBase<FlatJoinFunction<Rec
final int expCnt = valCnt1*valCnt2*Math.min(keyCnt1, keyCnt2);
setOutput(this.output);
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.MERGE);
getTaskConfig().setRelativeMemoryDriver(bnljn_frac);
......@@ -104,8 +104,8 @@ public class MatchTaskExternalITCase extends DriverTestBase<FlatJoinFunction<Rec
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(this.output);
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST);
......@@ -135,8 +135,8 @@ public class MatchTaskExternalITCase extends DriverTestBase<FlatJoinFunction<Rec
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(this.output);
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND);
......
......@@ -81,8 +81,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
final int valCnt2 = 2;
setOutput(this.outList);
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.MERGE);
getTaskConfig().setRelativeMemoryDriver(bnljn_frac);
......@@ -115,8 +115,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
int valCnt2 = 1;
setOutput(this.outList);
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.MERGE);
getTaskConfig().setRelativeMemoryDriver(bnljn_frac);
......@@ -151,8 +151,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
int valCnt2 = 20;
setOutput(this.outList);
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.MERGE);
getTaskConfig().setRelativeMemoryDriver(bnljn_frac);
......@@ -187,8 +187,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
int valCnt2 = 1;
setOutput(this.outList);
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.MERGE);
getTaskConfig().setRelativeMemoryDriver(bnljn_frac);
......@@ -223,8 +223,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
int valCnt2 = 20;
setOutput(this.outList);
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.MERGE);
getTaskConfig().setRelativeMemoryDriver(bnljn_frac);
......@@ -259,8 +259,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
int valCnt2 = 20;
setOutput(this.outList);
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.MERGE);
getTaskConfig().setRelativeMemoryDriver(bnljn_frac);
......@@ -295,8 +295,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
int valCnt2 = 20;
setOutput(this.outList);
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.MERGE);
getTaskConfig().setRelativeMemoryDriver(bnljn_frac);
......@@ -330,8 +330,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
int valCnt2 = 20;
setOutput(this.outList);
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.MERGE);
getTaskConfig().setRelativeMemoryDriver(bnljn_frac);
......@@ -366,8 +366,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
int valCnt2 = 20;
setOutput(new NirvanaOutputList());
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.MERGE);
getTaskConfig().setRelativeMemoryDriver(bnljn_frac);
......@@ -395,8 +395,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
int valCnt = 20;
setOutput(new NirvanaOutputList());
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.MERGE);
getTaskConfig().setRelativeMemoryDriver(bnljn_frac);
......@@ -446,8 +446,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
int valCnt = 20;
setOutput(new NirvanaOutputList());
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.MERGE);
getTaskConfig().setRelativeMemoryDriver(bnljn_frac);
......@@ -497,8 +497,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
int valCnt = 20;
setOutput(new NirvanaOutputList());
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
getTaskConfig().setDriverStrategy(DriverStrategy.MERGE);
getTaskConfig().setRelativeMemoryDriver(bnljn_frac);
......@@ -547,8 +547,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST);
......@@ -578,8 +578,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND);
......@@ -609,8 +609,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST);
......@@ -640,8 +640,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND);
......@@ -671,8 +671,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST);
......@@ -702,8 +702,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(new NirvanaOutputList());
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST);
......@@ -732,8 +732,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
addInput(new UniformRecordGenerator(keyCnt1, valCnt1, false));
addInput(new UniformRecordGenerator(keyCnt2, valCnt2, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(new NirvanaOutputList());
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND);
......@@ -760,8 +760,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
addInput(new DelayingInfinitiveInputIterator(100));
addInput(new UniformRecordGenerator(keyCnt, valCnt, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
......@@ -807,8 +807,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
addInput(new UniformRecordGenerator(keyCnt, valCnt, false));
addInput(new DelayingInfinitiveInputIterator(100));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(new NirvanaOutputList());
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND);
......@@ -851,8 +851,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
addInput(new UniformRecordGenerator(keyCnt, valCnt, false));
addInput(new UniformRecordGenerator(keyCnt, valCnt, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(new NirvanaOutputList());
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_FIRST);
......@@ -895,8 +895,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor
addInput(new UniformRecordGenerator(keyCnt, valCnt, false));
addInput(new UniformRecordGenerator(keyCnt, valCnt, false));
addInputComparator(this.comparator1);
addInputComparator(this.comparator2);
addDriverComparator(this.comparator1);
addDriverComparator(this.comparator2);
getTaskConfig().setDriverPairComparator(RecordPairComparatorFactory.get());
setOutput(new NirvanaOutputList());
getTaskConfig().setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND);
......
......@@ -60,7 +60,7 @@ public class ReduceTaskExternalITCase extends DriverTestBase<RichGroupReduceFunc
setNumFileHandlesForSort(2);
addInputComparator(this.comparator);
addDriverComparator(this.comparator);
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE);
......@@ -92,7 +92,7 @@ public class ReduceTaskExternalITCase extends DriverTestBase<RichGroupReduceFunc
setNumFileHandlesForSort(2);
addInputComparator(this.comparator);
addDriverComparator(this.comparator);
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE);
......@@ -123,7 +123,7 @@ public class ReduceTaskExternalITCase extends DriverTestBase<RichGroupReduceFunc
final int keyCnt = 8192;
final int valCnt = 8;
addInputComparator(this.comparator);
addDriverComparator(this.comparator);
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE);
......@@ -169,7 +169,7 @@ public class ReduceTaskExternalITCase extends DriverTestBase<RichGroupReduceFunc
int keyCnt = 32768;
int valCnt = 8;
addInputComparator(this.comparator);
addDriverComparator(this.comparator);
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE);
......
......@@ -61,7 +61,7 @@ public class ReduceTaskTest extends DriverTestBase<RichGroupReduceFunction<Recor
final int keyCnt = 100;
final int valCnt = 20;
addInputComparator(this.comparator);
addDriverComparator(this.comparator);
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE);
......@@ -91,7 +91,7 @@ public class ReduceTaskTest extends DriverTestBase<RichGroupReduceFunction<Recor
final int valCnt = 20;
addInput(new UniformRecordGenerator(keyCnt, valCnt, true));
addInputComparator(this.comparator);
addDriverComparator(this.comparator);
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE);
......@@ -118,7 +118,7 @@ public class ReduceTaskTest extends DriverTestBase<RichGroupReduceFunction<Recor
final int keyCnt = 100;
final int valCnt = 20;
addInputComparator(this.comparator);
addDriverComparator(this.comparator);
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE);
......@@ -163,7 +163,7 @@ public class ReduceTaskTest extends DriverTestBase<RichGroupReduceFunction<Recor
final int valCnt = 20;
addInput(new UniformRecordGenerator(keyCnt, valCnt, true));
addInputComparator(this.comparator);
addDriverComparator(this.comparator);
setOutput(this.outList);
getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE);
......@@ -185,7 +185,7 @@ public class ReduceTaskTest extends DriverTestBase<RichGroupReduceFunction<Recor
@Test
public void testCancelReduceTaskWhileSorting()
{
addInputComparator(this.comparator);
addDriverComparator(this.comparator);
setOutput(new NirvanaOutputList());
getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE);
......@@ -233,7 +233,7 @@ public class ReduceTaskTest extends DriverTestBase<RichGroupReduceFunction<Recor
final int valCnt = 2;
addInput(new UniformRecordGenerator(keyCnt, valCnt, true));
addInputComparator(this.comparator);
addDriverComparator(this.comparator);
setOutput(new NirvanaOutputList());
getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_REDUCE);
......
......@@ -85,6 +85,7 @@ public class ChainTaskTest extends TaskTestBase {
// driver
combineConfig.setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
combineConfig.setDriverComparator(compFact, 0);
combineConfig.setDriverComparator(compFact, 1);
combineConfig.setRelativeMemoryDriver(memoryFraction);
// udf
......
......@@ -171,7 +171,7 @@ public class TestTaskContext<S, T> implements PactTaskContext<S, T> {
@Override
@SuppressWarnings("unchecked")
public <X> TypeComparator<X> getInputComparator(int index) {
public <X> TypeComparator<X> getDriverComparator(int index) {
switch (index) {
case 0:
return (TypeComparator<X>) this.comparator1;
......
......@@ -118,7 +118,7 @@ public class DriverTestBase<S extends Function> implements PactTaskContext<S, Re
this.inputs.add(null);
}
public void addInputComparator(RecordComparator comparator) {
public void addDriverComparator(RecordComparator comparator) {
this.comparators.add(comparator);
}
......@@ -283,7 +283,7 @@ public class DriverTestBase<S extends Function> implements PactTaskContext<S, Re
}
@Override
public <X> TypeComparator<X> getInputComparator(int index) {
public <X> TypeComparator<X> getDriverComparator(int index) {
@SuppressWarnings("unchecked")
TypeComparator<X> comparator = (TypeComparator<X>) this.comparators.get(index);
return comparator;
......
......@@ -111,15 +111,16 @@ public class KMeansSingleStepTest extends CompilerTestBase {
assertEquals(DriverStrategy.SORTED_GROUP_COMBINE, combiner.getDriverStrategy());
assertNull(combiner.getInput().getLocalStrategyKeys());
assertNull(combiner.getInput().getLocalStrategySortOrder());
assertEquals(set0, combiner.getKeys());
assertEquals(set0, combiner.getKeys(0));
assertEquals(set0, combiner.getKeys(1));
// check the reducer
assertEquals(ShipStrategyType.PARTITION_HASH, reducer.getInput().getShipStrategy());
assertEquals(LocalStrategy.COMBININGSORT, reducer.getInput().getLocalStrategy());
assertEquals(DriverStrategy.SORTED_GROUP_REDUCE, reducer.getDriverStrategy());
assertEquals(set0, reducer.getKeys());
assertEquals(set0, reducer.getKeys(0));
assertEquals(set0, reducer.getInput().getLocalStrategyKeys());
assertTrue(Arrays.equals(reducer.getInput().getLocalStrategySortOrder(), reducer.getSortOrders()));
assertTrue(Arrays.equals(reducer.getInput().getLocalStrategySortOrder(), reducer.getSortOrders(0)));
// check the sink
assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
......
......@@ -290,8 +290,8 @@ public class RelationalQueryCompilerTest extends CompilerTestBase {
// local strategy keys
Assert.assertEquals(set01, reducer.getInput().getLocalStrategyKeys());
Assert.assertEquals(set01, reducer.getKeys());
Assert.assertTrue(Arrays.equals(reducer.getInput().getLocalStrategySortOrder(), reducer.getSortOrders()));
Assert.assertEquals(set01, reducer.getKeys(0));
Assert.assertTrue(Arrays.equals(reducer.getInput().getLocalStrategySortOrder(), reducer.getSortOrders(0)));
return true;
} else {
return false;
......@@ -314,8 +314,8 @@ public class RelationalQueryCompilerTest extends CompilerTestBase {
Assert.assertEquals(set0, join.getInput2().getLocalStrategyKeys());
Assert.assertTrue(Arrays.equals(join.getInput1().getLocalStrategySortOrder(), join.getInput2().getLocalStrategySortOrder()));
Assert.assertEquals(set01, reducer.getInput().getLocalStrategyKeys());
Assert.assertEquals(set01, reducer.getKeys());
Assert.assertTrue(Arrays.equals(reducer.getInput().getLocalStrategySortOrder(), reducer.getSortOrders()));
Assert.assertEquals(set01, reducer.getKeys(0));
Assert.assertTrue(Arrays.equals(reducer.getInput().getLocalStrategySortOrder(), reducer.getSortOrders(0)));
return true;
} else {
return false;
......@@ -337,8 +337,8 @@ public class RelationalQueryCompilerTest extends CompilerTestBase {
Assert.assertEquals(set01, join.getInput1().getLocalStrategyKeys());
Assert.assertEquals(set0, join.getInput2().getLocalStrategyKeys());
Assert.assertTrue(join.getInput1().getLocalStrategySortOrder()[0] == join.getInput2().getLocalStrategySortOrder()[0]);
Assert.assertEquals(set01, reducer.getKeys());
Assert.assertTrue(Arrays.equals(join.getInput1().getLocalStrategySortOrder(), reducer.getSortOrders()));
Assert.assertEquals(set01, reducer.getKeys(0));
Assert.assertTrue(Arrays.equals(join.getInput1().getLocalStrategySortOrder(), reducer.getSortOrders(0)));
return true;
} else {
return false;
......
......@@ -91,12 +91,12 @@ public class WordCountCompilerTest extends CompilerTestBase {
FieldList l = new FieldList(0);
Assert.assertEquals(l, c.getShipStrategyKeys());
Assert.assertEquals(l, c.getLocalStrategyKeys());
Assert.assertTrue(Arrays.equals(c.getLocalStrategySortOrder(), reducer.getSortOrders()));
Assert.assertTrue(Arrays.equals(c.getLocalStrategySortOrder(), reducer.getSortOrders(0)));
// check the combiner
SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getPredecessor();
Assert.assertEquals(DriverStrategy.SORTED_GROUP_COMBINE, combiner.getDriverStrategy());
Assert.assertEquals(l, combiner.getKeys());
Assert.assertEquals(l, combiner.getKeys(0));
Assert.assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
} catch (Exception e) {
......@@ -169,7 +169,8 @@ public class WordCountCompilerTest extends CompilerTestBase {
// check the combiner
SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getPredecessor();
Assert.assertEquals(DriverStrategy.SORTED_GROUP_COMBINE, combiner.getDriverStrategy());
Assert.assertEquals(l, combiner.getKeys());
Assert.assertEquals(l, combiner.getKeys(0));
Assert.assertEquals(l, combiner.getKeys(1));
Assert.assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
} catch (Exception e) {
e.printStackTrace();
......
......@@ -143,15 +143,16 @@ public class IterativeKMeansTest extends CompilerTestBase {
assertEquals(DriverStrategy.SORTED_GROUP_COMBINE, combiner.getDriverStrategy());
assertNull(combiner.getInput().getLocalStrategyKeys());
assertNull(combiner.getInput().getLocalStrategySortOrder());
assertEquals(set0, combiner.getKeys());
assertEquals(set0, combiner.getKeys(0));
assertEquals(set0, combiner.getKeys(1));
// check the reducer
assertEquals(ShipStrategyType.PARTITION_HASH, reducer.getInput().getShipStrategy());
assertTrue(reducer.getInput().isOnDynamicPath());
assertEquals(LocalStrategy.COMBININGSORT, reducer.getInput().getLocalStrategy());
assertEquals(DriverStrategy.SORTED_GROUP_REDUCE, reducer.getDriverStrategy());
assertEquals(set0, reducer.getKeys());
assertEquals(set0, reducer.getKeys(0));
assertEquals(set0, reducer.getInput().getLocalStrategyKeys());
assertTrue(Arrays.equals(reducer.getInput().getLocalStrategySortOrder(), reducer.getSortOrders()));
assertTrue(Arrays.equals(reducer.getInput().getLocalStrategySortOrder(), reducer.getSortOrders(0)));
}
}
......@@ -231,6 +231,7 @@ public class CustomCompensatableDanglingPageRankWithCombiner {
combinerConfig.setInputSerializer(vertexWithRankSerializer, 0);
combinerConfig.setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
combinerConfig.setDriverComparator(vertexWithRankComparator, 0);
combinerConfig.setDriverComparator(vertexWithRankComparator, 1);
combinerConfig.setRelativeMemoryDriver((double)coGroupSortMemory/totalMemoryConsumption);
combinerConfig.setOutputSerializer(vertexWithRankSerializer);
combinerConfig.addOutputShipStrategy(ShipStrategyType.PARTITION_HASH);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册