Commit e4d05f72 authored by Stephan Ewen

[FLINK-3312] Add accessors for various state types to RuntimeContext

Parent 6f755961
......@@ -209,7 +209,7 @@ public class DBStateCheckpointingTest extends StreamFaultToleranceTestBase {
failurePos = (new Random().nextLong() % (failurePosMax - failurePosMin)) + failurePosMin;
count = 0;
sum = getRuntimeContext().getPartitionedState(
sum = getRuntimeContext().getState(
new ValueStateDescriptor<>("my_state", 0L, LongSerializer.INSTANCE));
}
......@@ -237,11 +237,11 @@ public class DBStateCheckpointingTest extends StreamFaultToleranceTestBase {
@Override
public void open(Configuration parameters) throws IOException {
aCounts = getRuntimeContext().getPartitionedState(
aCounts = getRuntimeContext().getState(
new ValueStateDescriptor<>("a", NonSerializableLong.of(0L),
new KryoSerializer<>(NonSerializableLong.class, new ExecutionConfig())));
bCounts = getRuntimeContext().getPartitionedState(
bCounts = getRuntimeContext().getState(
new ValueStateDescriptor<>("b", 0L, LongSerializer.INSTANCE));
}
......
......@@ -31,10 +31,13 @@ import org.apache.flink.api.common.accumulators.Histogram;
import org.apache.flink.api.common.accumulators.IntCounter;
import org.apache.flink.api.common.accumulators.LongCounter;
import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.OperatorState;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
/**
......@@ -188,13 +191,15 @@ public interface RuntimeContext {
*/
DistributedCache getDistributedCache();
// --------------------------------------------------------------------------------------------
// ------------------------------------------------------------------------
// Methods for accessing state
// ------------------------------------------------------------------------
/**
* Gets the partitioned state, which is only accessible if the function is executed on
* a KeyedStream. When interacting with the state only the instance bound to the key of the
* element currently processed by the function is changed.
* Each operator may maintain multiple partitioned states, addressed with different names.
* Gets a handle to the system's key/value state. The key/value state is only accessible
* if the function is executed on a KeyedStream. On each access, the state exposes the value
* for the key of the element currently processed by the function.
* Each function may have multiple partitioned states, addressed with different names.
*
* <p>Because the scope of each value is the key of the currently processed element,
* and the elements are distributed by the Flink runtime, the system can transparently
......@@ -213,31 +218,111 @@ public interface RuntimeContext {
* private ValueState<Long> state;
*
* public void open(Configuration cfg) {
* state = getRuntimeContext().getPartitionedState(
* state = getRuntimeContext().getState(
* new ValueStateDescriptor<Long>("count", 0L, LongSerializer.INSTANCE));
* }
*
* public Tuple2<MyType, Long> map(MyType value) {
* long count = state.value();
* state.update(value + 1);
* long count = state.value() + 1;
* state.update(count);
* return new Tuple2<>(value, count);
* }
* });
* }</pre>
*
* @param stateProperties The descriptor defining the properties of the state.
*
* @param <T> The type of value stored in the state.
*
* @return The partitioned state object.
*
* @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
* function (function is not part of a KeyedStream).
*/
<T> ValueState<T> getState(ValueStateDescriptor<T> stateProperties);
/**
* Gets a handle to the system's key/value list state. This state is similar to the state
* accessed via {@link #getState(ValueStateDescriptor)}, but is optimized for state that
* holds lists. One can add elements to the list, or retrieve the list as a whole.
*
* <p>This state is only accessible if the function is executed on a KeyedStream.
*
* <pre>{@code
* DataStream<MyType> stream = ...;
* KeyedStream<MyType> keyedStream = stream.keyBy("id");
*
* keyedStream.flatMap(new RichFlatMapFunction<MyType, MyType>() {
*
* private ListState<MyType> state;
*
* public void open(Configuration cfg) {
* state = getRuntimeContext().getListState(
* new ListStateDescriptor<>("myState", MyType.class));
* }
*
* public void flatMap(MyType value, Collector<MyType> out) {
* if (value.isDivider()) {
* for (MyType t : state.get()) {
* out.collect(t);
* }
* } else {
* state.add(value);
* }
* }
* });
* }</pre>
*
* @param stateDescriptor The StateDescriptor that contains the name and type of the
* state that is being accessed.
* @param stateProperties The descriptor defining the properties of the state.
*
* @param <S> The type of the state.
* @param <T> The type of value stored in the state.
*
* @return The partitioned state object.
*
* @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
* function (function is not part of a KeyedStream).
*/
<S extends State> S getPartitionedState(StateDescriptor<S> stateDescriptor);
<T> ListState<T> getListState(ListStateDescriptor<T> stateProperties);
/**
* Gets a handle to the system's key/value reducing state. This state is similar to the state
* accessed via {@link #getState(ValueStateDescriptor)}, but is optimized for state that
* aggregates values.
*
* <p>This state is only accessible if the function is executed on a KeyedStream.
*
* <pre>{@code
* DataStream<MyType> stream = ...;
* KeyedStream<MyType> keyedStream = stream.keyBy("id");
*
* keyedStream.map(new RichMapFunction<MyType, Tuple2<MyType, Long>>() {
*
* private ReducingState<Long> sum;
*
* public void open(Configuration cfg) {
* sum = getRuntimeContext().getReducingState(
* new ReducingStateDescriptor<>("sum", Long.class, 0L, (a, b) -> a + b));
* }
*
* public Tuple2<MyType, Long> map(MyType value) {
* sum.add(value.count());
* return new Tuple2<>(value, sum.get());
* }
* });
*
* }</pre>
*
* @param stateProperties The descriptor defining the properties of the state.
*
* @param <T> The type of value stored in the state.
*
* @return The partitioned state object.
*
* @throws UnsupportedOperationException Thrown, if no partitioned state is available for the
* function (function is not part of a KeyedStream).
*/
<T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> stateProperties);
/**
* Gets the key/value state, which is only accessible if the function is executed on
* a KeyedStream. Upon calling {@link ValueState#value()}, the key/value state will
......@@ -292,7 +377,7 @@ public interface RuntimeContext {
* @throws UnsupportedOperationException Thrown, if no key/value state is available for the
* function (function is not part of a KeyedStream).
*
* @deprecated Use the more expressive {@link #getPartitionedState(StateDescriptor)} instead.
* @deprecated Use the more expressive {@link #getState(ValueStateDescriptor)} instead.
*/
@Deprecated
<S> OperatorState<S> getKeyValueState(String name, Class<S> stateType, S defaultState);
......@@ -344,7 +429,7 @@ public interface RuntimeContext {
* @throws UnsupportedOperationException Thrown, if no key/value state is available for the
* function (function is not part of a KeyedStream).
*
* @deprecated Use the more expressive {@link #getPartitionedState(StateDescriptor)} instead.
* @deprecated Use the more expressive {@link #getState(ValueStateDescriptor)} instead.
*/
@Deprecated
<S> OperatorState<S> getKeyValueState(String name, TypeInformation<S> stateType, S defaultState);
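
The two deprecated getKeyValueState variants above map directly onto the new getState accessor. A minimal migration sketch follows; the class and state names are illustrative and not part of this commit, and the function must still run on a KeyedStream:

import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.configuration.Configuration;

// Illustrative migration example, not part of this commit.
public class PerKeyCounter extends RichMapFunction<String, Long> {

    private ValueState<Long> count;

    @Override
    public void open(Configuration parameters) {
        // Before: count = getRuntimeContext().getKeyValueState("count", Long.class, 0L);
        // After: the state is described explicitly and fetched via the typed accessor.
        count = getRuntimeContext().getState(
                new ValueStateDescriptor<>("count", 0L, LongSerializer.INSTANCE));
    }

    @Override
    public Long map(String value) throws Exception {
        long newCount = count.value() + 1;
        count.update(newCount);
        return newCount;
    }
}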
......
......@@ -34,10 +34,13 @@ import org.apache.flink.api.common.accumulators.IntCounter;
import org.apache.flink.api.common.accumulators.LongCounter;
import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.OperatorState;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.core.fs.Path;
......@@ -172,10 +175,21 @@ public abstract class AbstractRuntimeUDFContext implements RuntimeContext {
}
@Override
public <S extends State> S getPartitionedState(StateDescriptor<S> stateDescriptor) {
public <T> ValueState<T> getState(ValueStateDescriptor<T> stateProperties) {
throw new UnsupportedOperationException(
"This state is only accessible by functions executed on a KeyedStream");
}
@Override
public <T> ListState<T> getListState(ListStateDescriptor<T> stateProperties) {
throw new UnsupportedOperationException(
"This state is only accessible by functions executed on a KeyedStream");
}
@Override
public <T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> stateProperties) {
throw new UnsupportedOperationException(
"This state is only accessible by functions executed on a KeyedStream");
}
@Override
......@@ -191,5 +205,4 @@ public abstract class AbstractRuntimeUDFContext implements RuntimeContext {
throw new UnsupportedOperationException(
"This state is only accessible by functions executed on a KeyedStream");
}
}
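
These defaults mean that a function running outside a KeyedStream fails fast when it tries to obtain partitioned state. A small illustrative sketch (names are hypothetical, not part of this commit; the descriptor constructor mirrors the one used in the getListState Javadoc example above):

import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.configuration.Configuration;

// Illustrative only: if the runtime context backing this function keeps the
// default implementations above (i.e. the function is not executed on a
// KeyedStream), the lookup in open() throws UnsupportedOperationException with
// "This state is only accessible by functions executed on a KeyedStream".
public class NonKeyedStateAccess extends RichMapFunction<String, String> {

    private ListState<String> buffer;

    @Override
    public void open(Configuration parameters) {
        buffer = getRuntimeContext().getListState(
                new ListStateDescriptor<>("buffer", String.class));
    }

    @Override
    public String map(String value) {
        return value;
    }
}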
......@@ -26,10 +26,13 @@ import org.apache.flink.api.common.accumulators.IntCounter;
import org.apache.flink.api.common.accumulators.LongCounter;
import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.OperatorState;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.runtime.operators.testutils.MockEnvironment;
......@@ -54,7 +57,7 @@ public class MockRuntimeContext extends StreamingRuntimeContext {
this.indexOfThisSubtask = indexOfThisSubtask;
}
private static class MockStreamOperator extends AbstractStreamOperator {
private static class MockStreamOperator extends AbstractStreamOperator<Integer> {
private static final long serialVersionUID = -1153976702711944427L;
@Override
......@@ -154,12 +157,22 @@ public class MockRuntimeContext extends StreamingRuntimeContext {
}
@Override
public <S> ValueState<S> getKeyValueState(String name, TypeInformation<S> stateType, S defaultState) {
public <S> OperatorState<S> getKeyValueState(String name, TypeInformation<S> stateType, S defaultState) {
throw new UnsupportedOperationException();
}
@Override
public <S extends State> S getPartitionedState(StateDescriptor<S> stateDescriptor) {
public <T> ValueState<T> getState(ValueStateDescriptor<T> stateProperties) {
throw new UnsupportedOperationException();
}
@Override
public <T> ListState<T> getListState(ListStateDescriptor<T> stateProperties) {
throw new UnsupportedOperationException();
}
@Override
public <T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> stateProperties) {
throw new UnsupportedOperationException();
}
}
......@@ -21,9 +21,12 @@ package org.apache.flink.streaming.api.operators;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.util.AbstractRuntimeUDFContext;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.OperatorState;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.TypeExtractor;
......@@ -106,11 +109,32 @@ public class StreamingRuntimeContext extends AbstractRuntimeUDFContext {
// ------------------------------------------------------------------------
@Override
public <S extends State> S getPartitionedState(StateDescriptor<S> stateDescriptor) {
public <T> ValueState<T> getState(ValueStateDescriptor<T> stateProperties) {
requireNonNull(stateProperties, "The state properties must not be null");
try {
return operator.getPartitionedState(stateDescriptor);
return operator.getPartitionedState(stateProperties);
} catch (Exception e) {
throw new RuntimeException("Error while getting state.", e);
throw new RuntimeException("Error while getting state", e);
}
}
@Override
public <T> ListState<T> getListState(ListStateDescriptor<T> stateProperties) {
requireNonNull(stateProperties, "The state properties must not be null");
try {
return operator.getPartitionedState(stateProperties);
} catch (Exception e) {
throw new RuntimeException("Error while getting state", e);
}
}
@Override
public <T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> stateProperties) {
requireNonNull(stateProperties, "The state properties must not be null");
try {
return operator.getPartitionedState(stateProperties);
} catch (Exception e) {
throw new RuntimeException("Error while getting state", e);
}
}
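
The three accessors above share the same null check, delegation to operator.getPartitionedState, and exception wrapping. A hypothetical consolidation sketch, not part of this commit, assuming each descriptor is parameterized by the state type it produces (as the delegation above suggests):

// Hypothetical helper inside StreamingRuntimeContext, not part of this commit.
// It assumes StateDescriptor<S> is generic over the produced state type S, so
// ValueStateDescriptor<T>, ListStateDescriptor<T> and ReducingStateDescriptor<T>
// can all be routed through the same lookup.
private <S extends State> S getPartitionedStateChecked(StateDescriptor<S> stateProperties) {
    requireNonNull(stateProperties, "The state properties must not be null");
    try {
        return operator.getPartitionedState(stateProperties);
    } catch (Exception e) {
        throw new RuntimeException("Error while getting state", e);
    }
}

@Override
public <T> ValueState<T> getState(ValueStateDescriptor<T> stateProperties) {
    return getPartitionedStateChecked(stateProperties);
}

The commit instead keeps three explicit method bodies, which keeps each accessor self-describing at the cost of some repetition.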
......@@ -138,8 +162,9 @@ public class StreamingRuntimeContext extends AbstractRuntimeUDFContext {
requireNonNull(name, "The name of the state must not be null");
requireNonNull(stateType, "The state type information must not be null");
ValueStateDescriptor<S> stateDesc = new ValueStateDescriptor<>(name, defaultState, stateType.createSerializer(getExecutionConfig()));
return getPartitionedState(stateDesc);
ValueStateDescriptor<S> stateProps =
new ValueStateDescriptor<>(name, defaultState, stateType.createSerializer(getExecutionConfig()));
return getState(stateProps);
}
// ------------------ expose (read only) relevant information from the stream config -------- //
......
......@@ -21,7 +21,8 @@ package org.apache.flink.streaming.runtime.operators.windowing;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.state.OperatorState;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
......@@ -521,11 +522,6 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
finalResult.addAll(out2.getElements());
assertEquals(numElements, finalResult.size());
synchronized (lock) {
op.close();
}
op.dispose();
Collections.sort(finalResult);
for (int i = 0; i < numElements; i++) {
assertEquals(i, finalResult.get(i).intValue());
......@@ -761,12 +757,13 @@ public class AccumulatingAlignedProcessingTimeWindowOperatorTest {
// get "volatile" style access to entries
static final Map<Integer, Integer> globalCounts = new ConcurrentHashMap<>();
private OperatorState<Integer> state;
private ValueState<Integer> state;
@Override
public void open(Configuration parameters) {
assertNotNull(getRuntimeContext());
state = getRuntimeContext().getKeyValueState("totalCount", Integer.class, 0);
state = getRuntimeContext().getState(
new ValueStateDescriptor<>("totalCount", 0, IntSerializer.INSTANCE));
}
@Override
......
......@@ -23,7 +23,8 @@ import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichReduceFunction;
import org.apache.flink.api.common.state.OperatorState;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
......@@ -934,7 +935,7 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
static final Map<Integer, Integer> globalCounts = new ConcurrentHashMap<>();
private OperatorState<Integer> state;
private ValueState<Integer> state;
@Override
public void open(Configuration parameters) {
......@@ -942,7 +943,8 @@ public class AggregatingAlignedProcessingTimeWindowOperatorTest {
// start with one, so the final count is correct and we test that we do not
// initialize with 0 always by default
state = getRuntimeContext().getKeyValueState("totalCount", Integer.class, 1);
state = getRuntimeContext().getState(
new ValueStateDescriptor<>("totalCount", 1, IntSerializer.INSTANCE));
}
@Override
......
......@@ -46,6 +46,6 @@ trait StatefulFunction[I, O, S] extends RichFunction {
override def open(c: Configuration) = {
val info = new ValueStateDescriptor[S]("state", null.asInstanceOf[S], stateSerializer)
state = getRuntimeContext().getPartitionedState[ValueState[S]](info)
state = getRuntimeContext().getState(info)
}
}
......@@ -223,7 +223,7 @@ public class EventTimeWindowCheckpointingITCase extends TestLogger {
public void open(Configuration parameters) {
assertEquals(PARALLELISM, getRuntimeContext().getNumberOfParallelSubtasks());
open = true;
count = getRuntimeContext().getPartitionedState(
count = getRuntimeContext().getState(
new ValueStateDescriptor<>("count", 0, IntSerializer.INSTANCE));
}
......
......@@ -172,7 +172,7 @@ public class PartitionedStateCheckpointingITCase extends StreamFaultToleranceTes
failurePos = (new Random().nextLong() % (failurePosMax - failurePosMin)) + failurePosMin;
count = 0;
sum = getRuntimeContext().getPartitionedState(
sum = getRuntimeContext().getState(
new ValueStateDescriptor<>("my_state", 0L, LongSerializer.INSTANCE));
}
......@@ -201,11 +201,11 @@ public class PartitionedStateCheckpointingITCase extends StreamFaultToleranceTes
@Override
public void open(Configuration parameters) throws IOException {
aCounts = getRuntimeContext().getPartitionedState(
aCounts = getRuntimeContext().getState(
new ValueStateDescriptor<>("a", NonSerializableLong.of(0L),
new KryoSerializer<>(NonSerializableLong.class, new ExecutionConfig())));
bCounts = getRuntimeContext().getPartitionedState(
bCounts = getRuntimeContext().getState(
new ValueStateDescriptor<>("b", 0L, LongSerializer.INSTANCE));
}
......
......@@ -255,7 +255,7 @@ public class StreamCheckpointingITCase extends StreamFaultToleranceTestBase {
failurePos = (new Random().nextLong() % (failurePosMax - failurePosMin)) + failurePosMin;
count = 0;
pCount = getRuntimeContext().getPartitionedState(
pCount = getRuntimeContext().getState(
new ValueStateDescriptor<>("pCount", 0L, LongSerializer.INSTANCE));
}
......