提交 22ac65ba 编写于 作者: G Greg Hogan 提交者: Stephan Ewen

[FLINK-2897] [runtime] Use distinct initial indices for OutputEmitter round-robin

This closes #1292
上级 868f97cf
......@@ -1256,18 +1256,19 @@ public class BatchTask<S extends Function, OT> extends AbstractInvokable impleme
{
// create the OutputEmitter from output ship strategy
final ShipStrategyType strategy = config.getOutputShipStrategy(i);
final int indexInSubtaskGroup = task.getIndexInSubtaskGroup();
final TypeComparatorFactory<T> compFactory = config.getOutputComparator(i, cl);
final ChannelSelector<SerializationDelegate<T>> oe;
if (compFactory == null) {
oe = new OutputEmitter<T>(strategy);
oe = new OutputEmitter<T>(strategy, indexInSubtaskGroup);
}
else {
final DataDistribution dataDist = config.getOutputDataDistribution(i, cl);
final Partitioner<?> partitioner = config.getOutputPartitioner(i, cl);
final TypeComparator<T> comparator = compFactory.createComparator();
oe = new OutputEmitter<T>(strategy, comparator, partitioner, dataDist);
oe = new OutputEmitter<T>(strategy, indexInSubtaskGroup, comparator, partitioner, dataDist);
}
final RecordWriter<SerializationDelegate<T>> recordWriter =
......
......@@ -28,9 +28,9 @@ import org.apache.flink.runtime.plugable.SerializationDelegate;
public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T>> {
private final ShipStrategyType strategy; // the shipping strategy used by this output emitter
private int[] channels; // the reused array defining target channels
private int nextChannelToSendTo = 0; // counter to go over channels round robin
private final TypeComparator<T> comparator; // the comparator for hashing / sorting
......@@ -47,16 +47,17 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
* Creates a new channel selector that distributes data round robin.
*/
public OutputEmitter() {
this(ShipStrategyType.NONE);
this(ShipStrategyType.NONE, 0);
}
/**
* Creates a new channel selector that uses the given strategy (broadcasting, partitioning, ...).
* Creates a new channel selector that uses the given strategy (broadcasting, partitioning, ...)
* and uses the supplied task index to perform a round robin distribution.
*
* @param strategy The distribution strategy to be used.
*/
public OutputEmitter(ShipStrategyType strategy) {
this(strategy, null);
public OutputEmitter(ShipStrategyType strategy, int indexInSubtaskGroup) {
this(strategy, indexInSubtaskGroup, null, null, null);
}
/**
......@@ -67,7 +68,7 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
* @param comparator The comparator used to hash / compare the records.
*/
public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator) {
this(strategy, comparator, null, null);
this(strategy, 0, comparator, null, null);
}
/**
......@@ -79,30 +80,33 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
* @param distr The distribution pattern used in the case of a range partitioning.
*/
public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator, DataDistribution distr) {
this(strategy, comparator, null, distr);
this(strategy, 0, comparator, null, distr);
}
public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator, Partitioner<?> partitioner) {
this(strategy, comparator, partitioner, null);
this(strategy, 0, comparator, partitioner, null);
}
@SuppressWarnings("unchecked")
public OutputEmitter(ShipStrategyType strategy, TypeComparator<T> comparator, Partitioner<?> partitioner, DataDistribution distr) {
public OutputEmitter(ShipStrategyType strategy, int indexInSubtaskGroup, TypeComparator<T> comparator, Partitioner<?> partitioner, DataDistribution distr) {
if (strategy == null) {
throw new NullPointerException();
}
this.strategy = strategy;
this.nextChannelToSendTo = indexInSubtaskGroup;
this.comparator = comparator;
this.partitioner = (Partitioner<Object>) partitioner;
switch (strategy) {
case PARTITION_CUSTOM:
extractedKeys = new Object[1];
case FORWARD:
case PARTITION_HASH:
case PARTITION_RANGE:
case PARTITION_RANDOM:
case PARTITION_FORCED_REBALANCE:
case PARTITION_CUSTOM:
channels = new int[1];
case BROADCAST:
break;
default:
......@@ -125,6 +129,7 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
public final int[] selectChannels(SerializationDelegate<T> record, int numberOfChannels) {
switch (strategy) {
case FORWARD:
return forward();
case PARTITION_RANDOM:
case PARTITION_FORCED_REBALANCE:
return robin(numberOfChannels);
......@@ -143,16 +148,24 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
// --------------------------------------------------------------------------------------------
/**
 * Selects the target channel for the FORWARD strategy.
 *
 * Hands back the reused {@code channels} array without writing to it. The constructor
 * allocates that array with a single element for FORWARD, and Java zero-initializes
 * new int arrays, so the result designates channel 0 as long as nothing else has
 * stored a different index into the array.
 *
 * @return the reused target-channel array
 */
private int[] forward() {
return this.channels;
}
private int[] robin(int numberOfChannels) {
if (this.channels == null || this.channels.length != 1) {
this.channels = new int[1];
int nextChannel = this.nextChannelToSendTo;
if (nextChannel >= numberOfChannels) {
if (nextChannel == numberOfChannels) {
nextChannel = 0;
} else {
nextChannel %= numberOfChannels;
}
}
int nextChannel = nextChannelToSendTo + 1;
nextChannel = nextChannel < numberOfChannels ? nextChannel : 0;
this.nextChannelToSendTo = nextChannel;
this.channels[0] = nextChannel;
this.nextChannelToSendTo = nextChannel + 1;
return this.channels;
}
......@@ -168,10 +181,6 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
}
private int[] hashPartitionDefault(T record, int numberOfChannels) {
if (channels == null || channels.length != 1) {
channels = new int[1];
}
int hash = this.comparator.hash(record);
hash = murmurHash(hash);
......@@ -212,11 +221,6 @@ public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T
}
private int[] customPartition(T record, int numberOfChannels) {
if (channels == null) {
channels = new int[1];
extractedKeys = new Object[1];
}
try {
if (comparator.extractKeys(record, extractedKeys, 0) == 1) {
final Object key = extractedKeys[0];
......
......@@ -151,7 +151,7 @@ public class OutputEmitterTest extends TestCase {
assertTrue(chans.length == 1);
assertTrue(chans[0] >= 0 && chans[0] <= numChans-1);
}
@Test
public void testForward() {
// Test for IntValue
......@@ -159,29 +159,27 @@ public class OutputEmitterTest extends TestCase {
final TypeComparator<Record> intComp = new RecordComparatorFactory(new int[] {0}, new Class[] {IntValue.class}).createComparator();
final ChannelSelector<SerializationDelegate<Record>> oe1 = new OutputEmitter<Record>(ShipStrategyType.FORWARD, intComp);
final SerializationDelegate<Record> delegate = new SerializationDelegate<Record>(new RecordSerializerFactory().getSerializer());
int numChannels = 100;
int numRecords = 50000;
int numRecords = 50000 + numChannels / 2;
int[] hit = new int[numChannels];
for (int i = 0; i < numRecords; i++) {
IntValue k = new IntValue(i);
Record rec = new Record(k);
delegate.setInstance(rec);
int[] chans = oe1.selectChannels(delegate, hit.length);
for(int j=0; j < chans.length; j++) {
hit[chans[j]]++;
}
}
int cnt = 0;
for (int i = 0; i < hit.length; i++) {
assertTrue(hit[i] == (numRecords/numChannels) || hit[i] == (numRecords/numChannels)-1);
cnt += hit[i];
assertTrue(hit[0] == numRecords);
for (int i = 1; i < hit.length; i++) {
assertTrue(hit[i] == 0);
}
assertTrue(cnt == numRecords);
// Test for StringValue
@SuppressWarnings("unchecked")
......@@ -189,15 +187,79 @@ public class OutputEmitterTest extends TestCase {
final ChannelSelector<SerializationDelegate<Record>> oe2 = new OutputEmitter<Record>(ShipStrategyType.FORWARD, stringComp);
numChannels = 100;
numRecords = 10000;
numRecords = 10000 + numChannels / 2;
hit = new int[numChannels];
for (int i = 0; i < numRecords; i++) {
StringValue k = new StringValue(i + "");
Record rec = new Record(k);
delegate.setInstance(rec);
int[] chans = oe2.selectChannels(delegate, hit.length);
for(int j=0; j < chans.length; j++) {
hit[chans[j]]++;
}
}
assertTrue(hit[0] == numRecords);
for (int i = 1; i < hit.length; i++) {
assertTrue(hit[i] == 0);
}
}
@Test
public void testForcedRebalance() {
// Test for IntValue
int numChannels = 100;
int toTaskIndex = numChannels * 6/7;
int fromTaskIndex = toTaskIndex + numChannels;
int extraRecords = numChannels * 1/3;
int numRecords = 50000 + extraRecords;
final ChannelSelector<SerializationDelegate<Record>> oe1 = new OutputEmitter<Record>(ShipStrategyType.PARTITION_FORCED_REBALANCE, fromTaskIndex);
final SerializationDelegate<Record> delegate = new SerializationDelegate<Record>(new RecordSerializerFactory().getSerializer());
int[] hit = new int[numChannels];
for (int i = 0; i < numRecords; i++) {
IntValue k = new IntValue(i);
Record rec = new Record(k);
delegate.setInstance(rec);
int[] chans = oe1.selectChannels(delegate, hit.length);
for(int j=0; j < chans.length; j++) {
hit[chans[j]]++;
}
}
int cnt = 0;
for (int i = 0; i < hit.length; i++) {
if (toTaskIndex <= i || i < toTaskIndex+extraRecords-numChannels) {
assertTrue(hit[i] == (numRecords/numChannels)+1);
} else {
assertTrue(hit[i] == numRecords/numChannels);
}
cnt += hit[i];
}
assertTrue(cnt == numRecords);
// Test for StringValue
numChannels = 100;
toTaskIndex = numChannels / 5;
fromTaskIndex = toTaskIndex + 2 * numChannels;
extraRecords = numChannels * 2/9;
numRecords = 10000 + extraRecords;
final ChannelSelector<SerializationDelegate<Record>> oe2 = new OutputEmitter<Record>(ShipStrategyType.PARTITION_FORCED_REBALANCE, fromTaskIndex);
hit = new int[numChannels];
for (int i = 0; i < numRecords; i++) {
StringValue k = new StringValue(i + "");
Record rec = new Record(k);
delegate.setInstance(rec);
int[] chans = oe2.selectChannels(delegate, hit.length);
for(int j=0; j < chans.length; j++) {
hit[chans[j]]++;
......@@ -206,11 +268,14 @@ public class OutputEmitterTest extends TestCase {
cnt = 0;
for (int i = 0; i < hit.length; i++) {
assertTrue(hit[i] == (numRecords/numChannels) || hit[i] == (numRecords/numChannels)-1);
if (toTaskIndex <= i && i < toTaskIndex+extraRecords) {
assertTrue(hit[i] == (numRecords/numChannels)+1);
} else {
assertTrue(hit[i] == numRecords/numChannels);
}
cnt += hit[i];
}
assertTrue(cnt == numRecords);
}
@Test
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册