Commit af477563 authored by Stephan Ewen

[FLINK-2763] [runtime] Fix hash table spilling partition selection.

Parent commit: 16afb8ec
......@@ -198,6 +198,19 @@ public class HashPartition<BT, PT> extends AbstractPagedInputView implements See
public final boolean isInMemory() {
return this.buildSideChannel == null;
}
/**
 * Gets the number of memory segments used by this partition, which includes build side
 * memory buffers and overflow memory segments.
 *
 * @return The number of occupied memory segments.
 */
public int getNumOccupiedMemorySegments() {
	// a spilled partition (no in-memory buffers) still holds one segment for spilling
	if (this.partitionBuffers == null) {
		return 1 + this.numOverflowSegments;
	}
	return this.partitionBuffers.length + this.numOverflowSegments;
}
public int getBuildSideBlockCount() {
return this.partitionBuffers == null ? this.buildSideWriteBuffer.getBlockCount() : this.partitionBuffers.length;
......@@ -284,7 +297,7 @@ public class HashPartition<BT, PT> extends AbstractPagedInputView implements See
throw new RuntimeException("Bug in Hybrid Hash Join: " +
"Request to spill a partition that has already been spilled.");
}
if (getBuildSideBlockCount() + this.numOverflowSegments < 2) {
if (getNumOccupiedMemorySegments() < 2) {
throw new RuntimeException("Bug in Hybrid Hash Join: " +
"Request to spill a partition with less than two buffers.");
}
......
......@@ -1093,8 +1093,8 @@ public class MutableHashTable<BT, PT> implements MemorySegmentSource {
for (int i = 0; i < partitions.size(); i++) {
HashPartition<BT, PT> p = partitions.get(i);
if (p.isInMemory() && p.getBuildSideBlockCount() > largestNumBlocks) {
largestNumBlocks = p.getBuildSideBlockCount();
if (p.isInMemory() && p.getNumOccupiedMemorySegments() > largestNumBlocks) {
largestNumBlocks = p.getNumOccupiedMemorySegments();
largestPartNum = i;
}
}
......
......@@ -21,19 +21,23 @@ package org.apache.flink.runtime.operators.hash;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypePairComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.ByteValueSerializer;
import org.apache.flink.api.common.typeutils.base.LongComparator;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.runtime.TupleComparator;
import org.apache.flink.api.java.typeutils.runtime.TupleSerializer;
import org.apache.flink.api.java.typeutils.runtime.ValueComparator;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync;
import org.apache.flink.types.ByteValue;
import org.apache.flink.util.MutableObjectIterator;
import org.junit.Test;
import org.mockito.Mockito;
import java.io.File;
import java.util.ArrayList;
......@@ -146,6 +150,47 @@ public class HashTableTest {
ioMan.shutdown();
}
}
/**
 * This tests the case where no additional partition buffers are used at the point when spilling
 * is triggered, testing that overflow bucket buffers are taken into account when deciding which
 * partition to spill.
 */
@Test
public void testSpillingFreesOnlyOverflowSegments() {
	final IOManager ioManager = new IOManagerAsync();

	// single-byte records: hash buckets overflow long before partition buffers fill up
	final TypeSerializer<ByteValue> serializer = ByteValueSerializer.INSTANCE;
	final TypeComparator<ByteValue> buildComparator = new ValueComparator<>(true, ByteValue.class);
	final TypeComparator<ByteValue> probeComparator = new ValueComparator<>(true, ByteValue.class);

	@SuppressWarnings("unchecked")
	final TypePairComparator<ByteValue, ByteValue> pairComparator = Mockito.mock(TypePairComparator.class);

	try {
		final int pageSize = 32 * 1024;
		final int numSegments = 34;

		final List<MemorySegment> memory = getMemory(numSegments, pageSize);

		final MutableHashTable<ByteValue, ByteValue> table = new MutableHashTable<>(
				serializer, serializer, buildComparator, probeComparator,
				pairComparator, memory, ioManager, 1, false);

		// huge build side forces spilling; one probe record keeps the join cheap
		table.open(new ByteValueIterator(100000000), new ByteValueIterator(1));

		table.close();

		checkNoTempFilesRemain(ioManager);
	}
	catch (Exception e) {
		e.printStackTrace();
		fail(e.getMessage());
	}
	finally {
		ioManager.shutdown();
	}
}
// ------------------------------------------------------------------------
// Utilities
......@@ -219,4 +264,28 @@ public class HashTableTest {
}
}
}
/**
 * Iterator that emits a fixed number of zero-valued {@link ByteValue} records,
 * then signals exhaustion by returning {@code null}.
 */
private static class ByteValueIterator implements MutableObjectIterator<ByteValue> {

	// total number of records to produce
	private final long numRecords;

	// count of records handed out so far
	private long emitted;

	ByteValueIterator(long numRecords) {
		this.numRecords = numRecords;
	}

	@Override
	public ByteValue next(ByteValue reuse) {
		// a fresh object is allocated each time; the reuse object is ignored
		return next();
	}

	@Override
	public ByteValue next() {
		if (emitted < numRecords) {
			emitted++;
			return new ByteValue((byte) 0);
		}
		return null;
	}
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册