提交 fa42cdab 编写于 作者: G Gordon Tai 提交者: Robert Metzger

[FLINK-4080] Guarantee exactly-once for Kinesis consumer for failures in the...

[FLINK-4080] Guarantee exactly-once for Kinesis consumer for failures in the middle of aggregated records

This closes #2180
上级 dbe41f48
......@@ -28,6 +28,7 @@ import org.apache.flink.streaming.connectors.kinesis.config.KinesisConfigConstan
import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.proxy.KinesisProxy;
import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema;
import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchemaWrapper;
......@@ -64,7 +65,7 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
* @param <T> the type of data emitted
*/
public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T>
implements CheckpointedAsynchronously<HashMap<KinesisStreamShard, String>>, ResultTypeQueryable<T> {
implements CheckpointedAsynchronously<HashMap<KinesisStreamShard, SequenceNumber>>, ResultTypeQueryable<T> {
private static final long serialVersionUID = 4724006128720664870L;
......@@ -92,10 +93,10 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T>
private transient KinesisDataFetcher fetcher;
/** The sequence numbers of the last fetched data records from Kinesis by this task */
private transient HashMap<KinesisStreamShard, String> lastSequenceNums;
private transient HashMap<KinesisStreamShard, SequenceNumber> lastSequenceNums;
/** The sequence numbers to restore to upon restore from failure */
private transient HashMap<KinesisStreamShard, String> sequenceNumsToRestore;
private transient HashMap<KinesisStreamShard, SequenceNumber> sequenceNumsToRestore;
private volatile boolean hasAssignedShards;
......@@ -227,14 +228,14 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T>
LOG.info("Consumer task {} is restoring sequence numbers from previous checkpointed state", thisConsumerTaskIndex);
}
for (Map.Entry<KinesisStreamShard, String> restoreSequenceNum : sequenceNumsToRestore.entrySet()) {
for (Map.Entry<KinesisStreamShard, SequenceNumber> restoreSequenceNum : sequenceNumsToRestore.entrySet()) {
// advance the corresponding shard to the last known sequence number
fetcher.advanceSequenceNumberTo(restoreSequenceNum.getKey(), restoreSequenceNum.getValue());
}
if (LOG.isInfoEnabled()) {
StringBuilder sb = new StringBuilder();
for (Map.Entry<KinesisStreamShard, String> restoreSequenceNo : sequenceNumsToRestore.entrySet()) {
for (Map.Entry<KinesisStreamShard, SequenceNumber> restoreSequenceNo : sequenceNumsToRestore.entrySet()) {
KinesisStreamShard shard = restoreSequenceNo.getKey();
sb.append(shard.getStreamName()).append(":").append(shard.getShardId())
.append(" -> ").append(restoreSequenceNo.getValue()).append(", ");
......@@ -265,14 +266,14 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T>
}
for (KinesisStreamShard assignedShard : assignedShards) {
fetcher.advanceSequenceNumberTo(assignedShard, sentinelSequenceNum.toString());
fetcher.advanceSequenceNumberTo(assignedShard, sentinelSequenceNum.get());
}
if (LOG.isInfoEnabled()) {
StringBuilder sb = new StringBuilder();
for (KinesisStreamShard assignedShard : assignedShards) {
sb.append(assignedShard.getStreamName()).append(":").append(assignedShard.getShardId())
.append(" -> ").append(sentinelSequenceNum.toString()).append(", ");
.append(" -> ").append(sentinelSequenceNum.get()).append(", ");
}
LOG.info("Advanced the starting sequence numbers of consumer task {}: {}", thisConsumerTaskIndex, sb.toString());
}
......@@ -335,7 +336,7 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T>
// ------------------------------------------------------------------------
@Override
public HashMap<KinesisStreamShard, String> snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
public HashMap<KinesisStreamShard, SequenceNumber> snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
if (lastSequenceNums == null) {
LOG.debug("snapshotState() requested on not yet opened source; returning null.");
return null;
......@@ -351,12 +352,14 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T>
}
@SuppressWarnings("unchecked")
HashMap<KinesisStreamShard, String> currentSequenceNums = (HashMap<KinesisStreamShard, String>) lastSequenceNums.clone();
HashMap<KinesisStreamShard, SequenceNumber> currentSequenceNums =
(HashMap<KinesisStreamShard, SequenceNumber>) lastSequenceNums.clone();
return currentSequenceNums;
}
@Override
public void restoreState(HashMap<KinesisStreamShard, String> restoredState) throws Exception {
public void restoreState(HashMap<KinesisStreamShard, SequenceNumber> restoredState) throws Exception {
sequenceNumsToRestore = restoredState;
}
......
......@@ -20,6 +20,7 @@ package org.apache.flink.streaming.connectors.kinesis.internals;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema;
import org.apache.flink.util.InstantiationUtil;
import org.slf4j.Logger;
......@@ -50,7 +51,7 @@ public class KinesisDataFetcher {
private final String taskName;
/** Information of the shards that this fetcher handles, along with the sequence numbers that they should start from */
private HashMap<KinesisStreamShard, String> assignedShardsWithStartingSequenceNum;
private HashMap<KinesisStreamShard, SequenceNumber> assignedShardsWithStartingSequenceNum;
/** Reference to the thread that executed run() */
private volatile Thread mainThread;
......@@ -71,7 +72,7 @@ public class KinesisDataFetcher {
this.configProps = checkNotNull(configProps);
this.assignedShardsWithStartingSequenceNum = new HashMap<>();
for (KinesisStreamShard shard : assignedShards) {
assignedShardsWithStartingSequenceNum.put(shard, SentinelSequenceNumber.SENTINEL_SEQUENCE_NUMBER_NOT_SET.toString());
assignedShardsWithStartingSequenceNum.put(shard, SentinelSequenceNumber.SENTINEL_SEQUENCE_NUMBER_NOT_SET.get());
}
this.taskName = taskName;
this.error = new AtomicReference<>();
......@@ -83,7 +84,7 @@ public class KinesisDataFetcher {
* @param streamShard the shard to perform the advance on
* @param sequenceNum the sequence number to advance to
*/
public void advanceSequenceNumberTo(KinesisStreamShard streamShard, String sequenceNum) {
public void advanceSequenceNumberTo(KinesisStreamShard streamShard, SequenceNumber sequenceNum) {
if (!assignedShardsWithStartingSequenceNum.containsKey(streamShard)) {
throw new IllegalArgumentException("Can't advance sequence number on a shard we are not going to read.");
}
......@@ -92,7 +93,7 @@ public class KinesisDataFetcher {
public <T> void run(SourceFunction.SourceContext<T> sourceContext,
KinesisDeserializationSchema<T> deserializationSchema,
HashMap<KinesisStreamShard, String> lastSequenceNums) throws Exception {
HashMap<KinesisStreamShard, SequenceNumber> lastSequenceNums) throws Exception {
if (assignedShardsWithStartingSequenceNum == null || assignedShardsWithStartingSequenceNum.size() == 0) {
throw new IllegalArgumentException("No shards set to read for this fetcher");
......@@ -104,7 +105,7 @@ public class KinesisDataFetcher {
// create a thread for each individual shard
ArrayList<ShardConsumerThread<?>> consumerThreads = new ArrayList<>(assignedShardsWithStartingSequenceNum.size());
for (Map.Entry<KinesisStreamShard, String> assignedShard : assignedShardsWithStartingSequenceNum.entrySet()) {
for (Map.Entry<KinesisStreamShard, SequenceNumber> assignedShard : assignedShardsWithStartingSequenceNum.entrySet()) {
ShardConsumerThread<T> thread = new ShardConsumerThread<>(this, configProps, assignedShard.getKey(),
assignedShard.getValue(), sourceContext, InstantiationUtil.clone(deserializationSchema), lastSequenceNums);
thread.setName(String.format("ShardConsumer - %s - %s/%s",
......
......@@ -25,9 +25,11 @@ import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.connectors.kinesis.config.KinesisConfigConstants;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.proxy.KinesisProxy;
import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.util.HashMap;
......@@ -42,7 +44,7 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
public class ShardConsumerThread<T> extends Thread {
private final SourceFunction.SourceContext<T> sourceContext;
private final KinesisDeserializationSchema<T> deserializer;
private final HashMap<KinesisStreamShard, String> seqNoState;
private final HashMap<KinesisStreamShard, SequenceNumber> seqNoState;
private final KinesisProxy kinesisProxy;
......@@ -52,18 +54,17 @@ public class ShardConsumerThread<T> extends Thread {
private final int maxNumberOfRecordsPerFetch;
private String lastSequenceNum;
private String nextShardItr;
private SequenceNumber lastSequenceNum;
private volatile boolean running = true;
public ShardConsumerThread(KinesisDataFetcher ownerRef,
Properties props,
KinesisStreamShard assignedShard,
String lastSequenceNum,
SequenceNumber lastSequenceNum,
SourceFunction.SourceContext<T> sourceContext,
KinesisDeserializationSchema<T> deserializer,
HashMap<KinesisStreamShard, String> seqNumState) {
HashMap<KinesisStreamShard, SequenceNumber> seqNumState) {
this.ownerRef = checkNotNull(ownerRef);
this.assignedShard = checkNotNull(assignedShard);
this.lastSequenceNum = checkNotNull(lastSequenceNum);
......@@ -79,56 +80,74 @@ public class ShardConsumerThread<T> extends Thread {
@SuppressWarnings("unchecked")
@Override
public void run() {
String nextShardItr;
try {
if (lastSequenceNum.equals(SentinelSequenceNumber.SENTINEL_LATEST_SEQUENCE_NUM.toString())) {
// before infinitely looping, we set the initial nextShardItr appropriately
if (lastSequenceNum.equals(SentinelSequenceNumber.SENTINEL_LATEST_SEQUENCE_NUM.get())) {
// if the shard is already closed, there will be no latest next record to get for this shard
if (assignedShard.isClosed()) {
nextShardItr = null;
} else {
nextShardItr = kinesisProxy.getShardIterator(assignedShard, ShardIteratorType.LATEST.toString(), null);
}
} else if (lastSequenceNum.equals(SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.toString())) {
} else if (lastSequenceNum.equals(SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get())) {
nextShardItr = kinesisProxy.getShardIterator(assignedShard, ShardIteratorType.TRIM_HORIZON.toString(), null);
} else if (lastSequenceNum.equals(SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM.toString())) {
} else if (lastSequenceNum.equals(SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM.get())) {
nextShardItr = null;
} else {
nextShardItr = kinesisProxy.getShardIterator(assignedShard, ShardIteratorType.AFTER_SEQUENCE_NUMBER.toString(), lastSequenceNum);
// we will be starting from an actual sequence number (due to restore from failure).
// if the last sequence number refers to an aggregated record, we need to clean up any dangling sub-records
// from the last aggregated record; otherwise, we can simply start iterating from the record right after.
if (lastSequenceNum.isAggregated()) {
String itrForLastAggregatedRecord =
kinesisProxy.getShardIterator(assignedShard, ShardIteratorType.AT_SEQUENCE_NUMBER.toString(), lastSequenceNum.getSequenceNumber());
// get only the last aggregated record
GetRecordsResult getRecordsResult = kinesisProxy.getRecords(itrForLastAggregatedRecord, 1);
List<UserRecord> fetchedRecords = deaggregateRecords(
getRecordsResult.getRecords(),
assignedShard.getStartingHashKey(),
assignedShard.getEndingHashKey());
long lastSubSequenceNum = lastSequenceNum.getSubSequenceNumber();
for (UserRecord record : fetchedRecords) {
// we have found a dangling sub-record if it has a larger subsequence number
// than our last sequence number; if so, collect the record and update state
if (record.getSubSequenceNumber() > lastSubSequenceNum) {
collectRecordAndUpdateState(record);
}
}
// set the nextShardItr so we can continue iterating in the next while loop
nextShardItr = getRecordsResult.getNextShardIterator();
} else {
// the last record was non-aggregated, so we can simply start from the next record
nextShardItr = kinesisProxy.getShardIterator(assignedShard, ShardIteratorType.AFTER_SEQUENCE_NUMBER.toString(), lastSequenceNum.getSequenceNumber());
}
}
while(running) {
if (nextShardItr == null) {
lastSequenceNum = SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM.toString();
synchronized (sourceContext.getCheckpointLock()) {
seqNoState.put(assignedShard, lastSequenceNum);
seqNoState.put(assignedShard, SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM.get());
}
break;
} else {
GetRecordsResult getRecordsResult = kinesisProxy.getRecords(nextShardItr, maxNumberOfRecordsPerFetch);
List<Record> fetchedRecords = getRecordsResult.getRecords();
// each of the Kinesis records may be aggregated, so we must deaggregate them before proceeding
fetchedRecords = deaggregateRecords(fetchedRecords, assignedShard.getStartingHashKey(), assignedShard.getEndingHashKey());
for (Record record : fetchedRecords) {
ByteBuffer recordData = record.getData();
byte[] dataBytes = new byte[recordData.remaining()];
recordData.get(dataBytes);
byte[] keyBytes = record.getPartitionKey().getBytes();
List<UserRecord> fetchedRecords = deaggregateRecords(
getRecordsResult.getRecords(),
assignedShard.getStartingHashKey(),
assignedShard.getEndingHashKey());
final T value = deserializer.deserialize(keyBytes, dataBytes,assignedShard.getStreamName(),
record.getSequenceNumber());
synchronized (sourceContext.getCheckpointLock()) {
sourceContext.collect(value);
seqNoState.put(assignedShard, record.getSequenceNumber());
}
lastSequenceNum = record.getSequenceNumber();
for (UserRecord record : fetchedRecords) {
collectRecordAndUpdateState(record);
}
nextShardItr = getRecordsResult.getNextShardIterator();
......@@ -144,8 +163,31 @@ public class ShardConsumerThread<T> extends Thread {
this.interrupt();
}
private void collectRecordAndUpdateState(UserRecord record) throws IOException {
ByteBuffer recordData = record.getData();
byte[] dataBytes = new byte[recordData.remaining()];
recordData.get(dataBytes);
byte[] keyBytes = record.getPartitionKey().getBytes();
final T value = deserializer.deserialize(keyBytes, dataBytes, assignedShard.getStreamName(),
record.getSequenceNumber());
synchronized (sourceContext.getCheckpointLock()) {
sourceContext.collect(value);
if (record.isAggregated()) {
seqNoState.put(
assignedShard,
new SequenceNumber(record.getSequenceNumber(), record.getSubSequenceNumber()));
} else {
seqNoState.put(assignedShard, new SequenceNumber(record.getSequenceNumber()));
}
}
}
@SuppressWarnings("unchecked")
protected static List<Record> deaggregateRecords(List<Record> records, String startingHashKey, String endingHashKey) {
return (List<Record>) (List<?>) UserRecord.deaggregate(records, new BigInteger(startingHashKey), new BigInteger(endingHashKey));
protected static List<UserRecord> deaggregateRecords(List<Record> records, String startingHashKey, String endingHashKey) {
return UserRecord.deaggregate(records, new BigInteger(startingHashKey), new BigInteger(endingHashKey));
}
}
......@@ -29,17 +29,27 @@ public enum SentinelSequenceNumber {
/** Flag value to indicate that the sequence number of a shard is not set. This value is used
* as an initial value in {@link KinesisDataFetcher}'s constructor for all shard's sequence number. */
SENTINEL_SEQUENCE_NUMBER_NOT_SET,
SENTINEL_SEQUENCE_NUMBER_NOT_SET( new SequenceNumber("SEQUENCE_NUMBER_NOT_SET") ),
/** Flag value for shard's sequence numbers to indicate that the
* shard should start to be read from the latest incoming records */
SENTINEL_LATEST_SEQUENCE_NUM,
SENTINEL_LATEST_SEQUENCE_NUM( new SequenceNumber("LATEST_SEQUENCE_NUM") ),
/** Flag value for shard's sequence numbers to indicate that the shard should
* start to be read from the earliest records that haven't expired yet */
SENTINEL_EARLIEST_SEQUENCE_NUM,
SENTINEL_EARLIEST_SEQUENCE_NUM( new SequenceNumber("EARLIEST_SEQUENCE_NUM") ),
/** Flag value to indicate that we have already read the last record of this shard
* (Note: Kinesis shards that have been closed due to a split or merge will have an ending data record) */
SENTINEL_SHARD_ENDING_SEQUENCE_NUM
SENTINEL_SHARD_ENDING_SEQUENCE_NUM( new SequenceNumber("SHARD_ENDING_SEQUENCE_NUM") );
private SequenceNumber sentinel;
SentinelSequenceNumber(SequenceNumber sentinel) {
this.sentinel = sentinel;
}
public SequenceNumber get() {
return sentinel;
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.streaming.connectors.kinesis.model;
import java.io.Serializable;
import static org.apache.flink.util.Preconditions.checkNotNull;
/**
 * A serializable, immutable identifier for the position of a record within a Kinesis shard.
 *
 * <p>It pairs the main Kinesis sequence number with an optional subsequence number. When this
 * {@link SequenceNumber} refers to a sub-record of an aggregated Kinesis record, the
 * subsequence number is a non-negative value giving the sub-record's position inside the
 * aggregate; for a plain, non-aggregated record it is -1.
 */
public class SequenceNumber implements Serializable {

	private static final long serialVersionUID = 876972197938972667L;

	private static final String DELIMITER = "-";

	// NOTE(review): instances are Java-serialized as part of checkpointed consumer state,
	// so the private field names and the explicit serialVersionUID must stay stable.
	private final String sequenceNumber;
	private final long subSequenceNumber;

	/** Hash precomputed at construction time; both fields are final, so it never goes stale. */
	private final int cachedHash;

	/**
	 * Creates a sequence number for a non-aggregated Kinesis record (no subsequence number).
	 *
	 * @param sequenceNumber the main sequence number
	 */
	public SequenceNumber(String sequenceNumber) {
		this(sequenceNumber, -1);
	}

	/**
	 * Creates a sequence number with an explicit subsequence number.
	 * Pass -1 as the subsequence number to represent a non-aggregated Kinesis record;
	 * any non-negative value marks this as referring to a sub-record of an aggregated record.
	 *
	 * @param sequenceNumber the main sequence number
	 * @param subSequenceNumber the subsequence number (-1 for non-aggregated Kinesis records)
	 */
	public SequenceNumber(String sequenceNumber, long subSequenceNumber) {
		if (sequenceNumber == null) {
			throw new NullPointerException();
		}
		this.sequenceNumber = sequenceNumber;
		this.subSequenceNumber = subSequenceNumber;
		// keep this exact formula: serialized instances carry the cached value, so the
		// computation must not drift between versions
		this.cachedHash = 37 * (sequenceNumber.hashCode() + Long.hashCode(subSequenceNumber));
	}

	/** Whether this position refers to a sub-record of an aggregated Kinesis record. */
	public boolean isAggregated() {
		return subSequenceNumber >= 0;
	}

	/** Returns the main Kinesis sequence number. */
	public String getSequenceNumber() {
		return sequenceNumber;
	}

	/** Returns the subsequence number, or -1 for a non-aggregated Kinesis record. */
	public long getSubSequenceNumber() {
		return subSequenceNumber;
	}

	@Override
	public String toString() {
		// aggregated positions render as "<seqNum>-<subSeqNum>", plain ones as just "<seqNum>"
		return isAggregated()
			? sequenceNumber + DELIMITER + subSequenceNumber
			: sequenceNumber;
	}

	@Override
	public boolean equals(Object obj) {
		if (this == obj) {
			return true;
		}
		if (!(obj instanceof SequenceNumber)) {
			return false;
		}
		SequenceNumber that = (SequenceNumber) obj;
		return this.sequenceNumber.equals(that.sequenceNumber)
			&& this.subSequenceNumber == that.subSequenceNumber;
	}

	@Override
	public int hashCode() {
		return cachedHash;
	}
}
......@@ -22,6 +22,7 @@ import org.apache.flink.streaming.connectors.kinesis.config.KinesisConfigConstan
import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.proxy.KinesisProxy;
import org.apache.flink.streaming.connectors.kinesis.testutils.ReferenceKinesisShardTopologies;
import org.apache.flink.streaming.connectors.kinesis.testutils.TestableFlinkKinesisConsumer;
......@@ -344,7 +345,7 @@ public class FlinkKinesisConsumerTest {
dummyConsumer.open(new Configuration());
for (KinesisStreamShard shard : fakeAssignedShardsToThisConsumerTask) {
verify(kinesisDataFetcherMock).advanceSequenceNumberTo(shard, SentinelSequenceNumber.SENTINEL_LATEST_SEQUENCE_NUM.toString());
verify(kinesisDataFetcherMock).advanceSequenceNumberTo(shard, SentinelSequenceNumber.SENTINEL_LATEST_SEQUENCE_NUM.get());
}
}
......@@ -380,7 +381,7 @@ public class FlinkKinesisConsumerTest {
dummyConsumer.open(new Configuration());
for (KinesisStreamShard shard : fakeAssignedShardsToThisConsumerTask) {
verify(kinesisDataFetcherMock).advanceSequenceNumberTo(shard, SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.toString());
verify(kinesisDataFetcherMock).advanceSequenceNumberTo(shard, SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get());
}
}
......@@ -414,14 +415,14 @@ public class FlinkKinesisConsumerTest {
null, null, false, false);
// generate random UUIDs as sequence numbers of last checkpointed state for each assigned shard
ArrayList<String> listOfSeqNumIfAssignedShards = new ArrayList<>(fakeAssignedShardsToThisConsumerTask.size());
ArrayList<SequenceNumber> listOfSeqNumOfAssignedShards = new ArrayList<>(fakeAssignedShardsToThisConsumerTask.size());
for (KinesisStreamShard shard : fakeAssignedShardsToThisConsumerTask) {
listOfSeqNumIfAssignedShards.add(UUID.randomUUID().toString());
listOfSeqNumOfAssignedShards.add(new SequenceNumber(UUID.randomUUID().toString()));
}
HashMap<KinesisStreamShard, String> fakeRestoredState = new HashMap<>();
HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = new HashMap<>();
for (int i=0; i<fakeAssignedShardsToThisConsumerTask.size(); i++) {
fakeRestoredState.put(fakeAssignedShardsToThisConsumerTask.get(i), listOfSeqNumIfAssignedShards.get(i));
fakeRestoredState.put(fakeAssignedShardsToThisConsumerTask.get(i), listOfSeqNumOfAssignedShards.get(i));
}
dummyConsumer.restoreState(fakeRestoredState);
......@@ -430,7 +431,7 @@ public class FlinkKinesisConsumerTest {
for (int i=0; i<fakeAssignedShardsToThisConsumerTask.size(); i++) {
verify(kinesisDataFetcherMock).advanceSequenceNumberTo(
fakeAssignedShardsToThisConsumerTask.get(i),
listOfSeqNumIfAssignedShards.get(i));
listOfSeqNumOfAssignedShards.get(i));
}
}
......
......@@ -18,6 +18,7 @@
package org.apache.flink.streaming.connectors.kinesis.internals;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.testutils.ReferenceKinesisShardTopologies;
import org.junit.Rule;
import org.junit.Test;
......@@ -42,7 +43,7 @@ public class KinesisDataFetcherTest {
KinesisDataFetcher fetcherUnderTest = new KinesisDataFetcher(assignedShardsToThisFetcher, new Properties(), "fake-task-name");
// advance the fetcher on a shard that it does not own
fetcherUnderTest.advanceSequenceNumberTo(fakeCompleteListOfShards.get(2), "fake-seq-num");
fetcherUnderTest.advanceSequenceNumberTo(fakeCompleteListOfShards.get(2), new SequenceNumber("fake-seq-num"));
}
}
......@@ -17,12 +17,18 @@
package org.apache.flink.streaming.connectors.kinesis.internals;
import com.amazonaws.services.kinesis.model.*;
import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
import com.amazonaws.services.kinesis.model.HashKeyRange;
import com.amazonaws.services.kinesis.model.Record;
import com.amazonaws.services.kinesis.model.Shard;
import org.apache.commons.lang.StringUtils;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
import org.apache.flink.streaming.connectors.kinesis.proxy.KinesisProxy;
import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchemaWrapper;
import org.apache.flink.streaming.util.serialization.SimpleStringSchema;
......@@ -35,11 +41,12 @@ import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Properties;
import java.util.HashMap;
import java.util.UUID;
import java.util.LinkedList;
import static org.junit.Assert.assertEquals;
......@@ -111,31 +118,31 @@ public class ShardConsumerThreadTest {
// so we are mocking the static deaggregateRecords() to return the original list of records
PowerMockito.mockStatic(ShardConsumerThread.class);
PowerMockito.when(ShardConsumerThread.deaggregateRecords(Matchers.anyListOf(Record.class), Matchers.anyString(), Matchers.anyString()))
.thenReturn(getRecordsResultFirst.getRecords())
.thenReturn(getRecordsResultSecond.getRecords())
.thenReturn(getRecordsResultThird.getRecords())
.thenReturn(getRecordsResultFourth.getRecords())
.thenReturn(getRecordsResultFifth.getRecords())
.thenReturn(getRecordsResultFinal.getRecords());
.thenReturn(convertRecordsToUserRecords(getRecordsResultFirst.getRecords()))
.thenReturn(convertRecordsToUserRecords(getRecordsResultSecond.getRecords()))
.thenReturn(convertRecordsToUserRecords(getRecordsResultThird.getRecords()))
.thenReturn(convertRecordsToUserRecords(getRecordsResultFourth.getRecords()))
.thenReturn(convertRecordsToUserRecords(getRecordsResultFifth.getRecords()))
.thenReturn(convertRecordsToUserRecords(getRecordsResultFinal.getRecords()));
// ------------------------------------------------------------------------------------------
Properties testConsumerConfig = new Properties();
HashMap<KinesisStreamShard, String> seqNumState = new HashMap<>();
HashMap<KinesisStreamShard, SequenceNumber> seqNumState = new HashMap<>();
DummySourceContext dummySourceContext = new DummySourceContext();
ShardConsumerThread dummyShardConsumerThread = getDummyShardConsumerThreadWithMockedKinesisProxy(
dummySourceContext, kinesisProxyMock, Mockito.mock(KinesisDataFetcher.class),
testConsumerConfig, assignedShardUnderTest, "fake-last-seq-num", seqNumState);
testConsumerConfig, assignedShardUnderTest, new SequenceNumber("fake-last-seq-num"), seqNumState);
dummyShardConsumerThread.run();
// the final sequence number state for the assigned shard to this consumer thread
// should store SENTINEL_SHARD_ENDING_SEQUENCE_NUMBER since the final nextShardItr should be null
assertEquals(seqNumState.get(assignedShardUnderTest), SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM.toString());
assertEquals(SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM.get(), seqNumState.get(assignedShardUnderTest));
// the number of elements collected should equal the number of records generated by mocked KinesisProxy
assertEquals(dummySourceContext.getNumOfElementsCollected(), totalRecordCount);
assertEquals(totalRecordCount, dummySourceContext.getNumOfElementsCollected());
}
private ShardConsumerThread getDummyShardConsumerThreadWithMockedKinesisProxy(
......@@ -144,8 +151,8 @@ public class ShardConsumerThreadTest {
KinesisDataFetcher owningFetcherRefMock,
Properties testConsumerConfig,
KinesisStreamShard assignedShard,
String lastSequenceNum,
HashMap<KinesisStreamShard, String> seqNumState) {
SequenceNumber lastSequenceNum,
HashMap<KinesisStreamShard, SequenceNumber> seqNumState) {
try {
PowerMockito.whenNew(KinesisProxy.class).withArguments(testConsumerConfig).thenReturn(kinesisProxyMock);
......@@ -159,7 +166,7 @@ public class ShardConsumerThreadTest {
private List<Record> generateFakeListOfRecordsFromToIncluding(int startingSeq, int endingSeq) {
List<Record> fakeListOfRecords = new LinkedList<>();
for (int i=0; i <= (endingSeq - startingSeq); i++) {
for (int i=startingSeq; i <= endingSeq; i++) {
fakeListOfRecords.add(new Record()
.withData(ByteBuffer.wrap(String.valueOf(i).getBytes()))
.withPartitionKey(UUID.randomUUID().toString()) // the partition key assigned doesn't matter here
......@@ -168,6 +175,14 @@ public class ShardConsumerThreadTest {
return fakeListOfRecords;
}
private List<UserRecord> convertRecordsToUserRecords(List<Record> records) {
List<UserRecord> converted = new ArrayList<>(records.size());
for (Record record : records) {
converted.add(new UserRecord(record));
}
return converted;
}
private static class DummySourceContext implements SourceFunction.SourceContext<String> {
private static final Object lock = new Object();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册