diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/IndexedInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/IndexedInputGate.java index be837f4b1395f6d3f7261e1d8bfb1d1ee9665f77..686c01524fe86460a56b335dccbe807f38cff562 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/IndexedInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/IndexedInputGate.java @@ -23,6 +23,8 @@ import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import java.io.IOException; +import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY; + /** An {@link InputGate} with a specific index. */ public abstract class IndexedInputGate extends InputGate implements CheckpointableInput { /** Returns the index of this input gate. Only supported on */ @@ -30,6 +32,9 @@ public abstract class IndexedInputGate extends InputGate implements Checkpointab @Override public void checkpointStarted(CheckpointBarrier barrier) throws CheckpointException { + if (!getStateConsumedFuture().isDone()) { + throw new CheckpointException(CHECKPOINT_DECLINED_TASK_NOT_READY); + } for (int index = 0, numChannels = getNumberOfInputChannels(); index < numChannels; index++) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java index d9cad1747a5f7ff89a8503da5f962da425ae0fd1..68ccd0084e0a3df7d9602e08d0b971dc74f92aff 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java @@ -21,6 +21,7 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.metrics.groups.UnregisteredMetricsGroup; +import org.apache.flink.runtime.checkpoint.CheckpointException; import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter; import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; import org.apache.flink.runtime.clusterframework.types.ResourceID; @@ -32,6 +33,7 @@ import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder; import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.TaskEventPublisher; import org.apache.flink.runtime.io.network.TestingConnectionManager; +import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferCompressor; import org.apache.flink.runtime.io.network.buffer.BufferDecompressor; @@ -70,11 +72,15 @@ import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; import static java.util.Arrays.asList; +import static org.apache.flink.runtime.checkpoint.CheckpointOptions.alignedNoTimeout; +import static org.apache.flink.runtime.checkpoint.CheckpointType.CHECKPOINT; import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createLocalInputChannel; import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createSingleInputGate; import static org.apache.flink.runtime.io.network.partition.InputGateFairnessTest.setupInputGate; import static org.apache.flink.runtime.io.network.util.TestBufferFactory.createBuffer; +import static org.apache.flink.runtime.state.CheckpointStorageLocationReference.getDefault; import static org.apache.flink.runtime.util.NettyShuffleDescriptorBuilder.createRemoteWithIdAndLocation; +import static org.apache.flink.util.Preconditions.checkState; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; @@ -88,6 +94,14 @@ import static org.junit.Assert.fail; /** Tests for {@link SingleInputGate}. */ public class SingleInputGateTest extends InputGateTestBase { + @Test(expected = CheckpointException.class) + public void testCheckpointsDeclinedUnlessStateConsumed() throws CheckpointException { + SingleInputGate gate = createInputGate(createNettyShuffleEnvironment()); + checkState(!gate.getStateConsumedFuture().isDone()); + gate.checkpointStarted( + new CheckpointBarrier(1L, 1L, alignedNoTimeout(CHECKPOINT, getDefault()))); + } + /** * Tests {@link InputGate#setup()} should create the respective {@link BufferPool} and assign * exclusive buffers for {@link RemoteInputChannel}s, but should not request partitions. diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java index 903998ae62f3850a0277a09097cabab5d872856e..24027a60576a03c822a5be1c240ab0804c9eb006 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java @@ -38,7 +38,6 @@ import org.apache.flink.streaming.api.functions.KeyedProcessFunction; import org.apache.flink.streaming.api.functions.co.RichCoFlatMapFunction; import org.apache.flink.util.Collector; -import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -125,8 +124,9 @@ public class UnalignedCheckpointITCase extends UnalignedCheckpointTestBase { }, new Object[] {"Parallel cogroup, p = 5", createCogroupSettings(5)}, new Object[] {"Parallel cogroup, p = 10", createCogroupSettings(10)}, - new Object[] {"Parallel union, p = 5", createUnionSettings(5)}, - new Object[] {"Parallel union, p = 10", createUnionSettings(10)}, + // todo: enable after completely fixing FLINK-20654 + // new Object[] {"Parallel union, p = 5", createUnionSettings(5)}, + // new Object[] {"Parallel union, p = 10", createUnionSettings(10)}, }; } @@ -189,7 +189,6 @@ public class UnalignedCheckpointITCase extends UnalignedCheckpointTestBase { } @Test - @Ignore public void execute() throws Exception { execute(settings); }