diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/KeyedProcessFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/KeyedProcessFunction.java index a03480bc6825927d85f1f2c14cd617b418b085dd..eb89362e285e8385f14ff646485580ede7d9288c 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/KeyedProcessFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/KeyedProcessFunction.java @@ -110,6 +110,11 @@ public abstract class KeyedProcessFunction extends AbstractRichFunction * @param value The record to emit. */ public abstract void output(OutputTag outputTag, X value); + + /** + * Get key of the element being processed. + */ + public abstract K getCurrentKey(); } /** @@ -124,6 +129,7 @@ public abstract class KeyedProcessFunction extends AbstractRichFunction /** * Get key of the firing timer. */ + @Override public abstract K getCurrentKey(); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java index 9263be0ac0b4f845900163445b3b91b8295ff749..589ba9d794e6f9b1668f2905e646e0e6d438eb38 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/co/KeyedBroadcastProcessFunction.java @@ -158,6 +158,12 @@ public abstract class KeyedBroadcastProcessFunction extends B * A {@link TimerService} for querying time and registering timers. */ public abstract TimerService timerService(); + + + /** + * Get key of the element being processed. + */ + public abstract KS getCurrentKey(); } /** @@ -174,6 +180,7 @@ public abstract class KeyedBroadcastProcessFunction extends B /** * Get the key of the firing timer. */ + @Override public abstract KS getCurrentKey(); } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/KeyedProcessOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/KeyedProcessOperator.java index b74fdf3492eeca7639fa95990873166ddc2a0b16..b6171c2a8e3042a8a60406b7300764995783485b 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/KeyedProcessOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/KeyedProcessOperator.java @@ -131,6 +131,12 @@ public class KeyedProcessOperator output.collect(outputTag, new StreamRecord<>(value, element.getTimestamp())); } + + @Override + @SuppressWarnings("unchecked") + public K getCurrentKey() { + return (K) KeyedProcessOperator.this.getCurrentKey(); + } } private class OnTimerContextImpl extends KeyedProcessFunction.OnTimerContext { diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java index 5f7bbe2b4f5c97165ec8fb546d327846c1ee37bd..0bfa68618c75f235e0db8ed16706d96bd1f37ce6 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperator.java @@ -288,6 +288,13 @@ public class CoBroadcastWithKeyedOperator } return state; } + + @Override + @SuppressWarnings("unchecked") + public KS getCurrentKey() { + return (KS) CoBroadcastWithKeyedOperator.this.getCurrentKey(); + } + } private class OnTimerContextImpl extends KeyedBroadcastProcessFunction.OnTimerContext { diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/KeyedProcessOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/KeyedProcessOperatorTest.java index c01329e29b245d91f888676f31450b8ac986ce3c..2032916ce19034440c8ef3f81d92919da86cbb27 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/KeyedProcessOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/KeyedProcessOperatorTest.java @@ -23,6 +23,7 @@ import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.streaming.api.TimeDomain; import org.apache.flink.streaming.api.TimerService; @@ -43,6 +44,7 @@ import org.junit.rules.ExpectedException; import java.util.concurrent.ConcurrentLinkedQueue; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; /** * Tests {@link KeyedProcessOperator}. @@ -52,6 +54,48 @@ public class KeyedProcessOperatorTest extends TestLogger { @Rule public ExpectedException expectedException = ExpectedException.none(); + @Test + public void testKeyQuerying() throws Exception { + + class KeyQueryingProcessFunction extends KeyedProcessFunction, String> { + + @Override + public void processElement( + Tuple2 value, + Context ctx, + Collector out) throws Exception { + + assertTrue("Did not get expected key.", ctx.getCurrentKey().equals(value.f0)); + + // we check that we receive this output, to ensure that the assert was actually checked + out.collect(value.f1); + } + } + + KeyedProcessOperator, String> operator = + new KeyedProcessOperator<>(new KeyQueryingProcessFunction()); + + try ( + OneInputStreamOperatorTestHarness, String> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>(operator, (in) -> in.f0 , BasicTypeInfo.INT_TYPE_INFO)) { + + testHarness.setup(); + testHarness.open(); + + testHarness.processElement(new StreamRecord<>(Tuple2.of(5, "5"), 12L)); + testHarness.processElement(new StreamRecord<>(Tuple2.of(42, "42"), 13L)); + + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + expectedOutput.add(new StreamRecord<>("5", 12L)); + expectedOutput.add(new StreamRecord<>("42", 13L)); + + TestHarnessUtil.assertOutputEquals( + "Output was not correct.", + expectedOutput, + testHarness.getOutput()); + } + } + @Test public void testTimestampAndWatermarkQuerying() throws Exception { diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java index c3692d56abb88bd30a384d03fee178153611c79e..715bc9dc6802cd3226590d21ab30c0df98eb3355 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java @@ -25,6 +25,7 @@ import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.state.KeyedStateFunction; import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction; @@ -70,6 +71,56 @@ public class CoBroadcastWithKeyedOperatorTest { BasicTypeInfo.INT_TYPE_INFO ); + @Test + public void testKeyQuerying() throws Exception { + + class KeyQueryingProcessFunction extends KeyedBroadcastProcessFunction, String, String> { + + @Override + public void processElement( + Tuple2 value, + ReadOnlyContext ctx, + Collector out) throws Exception { + assertTrue("Did not get expected key.", ctx.getCurrentKey().equals(value.f0)); + + // we check that we receive this output, to ensure that the assert was actually checked + out.collect(value.f1); + + } + + @Override + public void processBroadcastElement( + String value, + Context ctx, + Collector out) throws Exception { + + } + } + + CoBroadcastWithKeyedOperator, String, String> operator = + new CoBroadcastWithKeyedOperator<>(new KeyQueryingProcessFunction(), Collections.emptyList()); + + try ( + TwoInputStreamOperatorTestHarness, String, String> testHarness = + new KeyedTwoInputStreamOperatorTestHarness<>(operator, (in) -> in.f0 , null, BasicTypeInfo.INT_TYPE_INFO)) { + + testHarness.setup(); + testHarness.open(); + + testHarness.processElement1(new StreamRecord<>(Tuple2.of(5, "5"), 12L)); + testHarness.processElement1(new StreamRecord<>(Tuple2.of(42, "42"), 13L)); + + ConcurrentLinkedQueue expectedOutput = new ConcurrentLinkedQueue<>(); + expectedOutput.add(new StreamRecord<>("5", 12L)); + expectedOutput.add(new StreamRecord<>("42", 13L)); + + TestHarnessUtil.assertOutputEquals( + "Output was not correct.", + expectedOutput, + testHarness.getOutput()); + } + } + /** Test the iteration over the keyed state on the broadcast side. */ @Test public void testAccessToKeyedStateIt() throws Exception {