Commit addd0842 authored by Stefan Richter, committed by Aljoscha Krettek

[FLINK-3761] Refactor RocksDB Backend/Make Key-Group Aware

This change makes the RocksDB backend key-group aware by building on the
changes in the previous commit.
Parent 4809f536
@@ -20,8 +20,11 @@ package org.apache.flink.contrib.streaming.state;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
 import org.apache.flink.runtime.state.KvState;
 import org.apache.flink.util.Preconditions;
 import org.rocksdb.ColumnFamilyHandle;
@@ -30,7 +33,6 @@ import org.rocksdb.WriteOptions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

-import java.io.ByteArrayOutputStream;
 import java.io.IOException;

 /**
@@ -56,7 +58,7 @@ public abstract class AbstractRocksDBState<K, N, S extends State, SD extends Sta
 	private N currentNamespace;

 	/** Backend that holds the actual RocksDB instance where we store state */
-	protected RocksDBKeyedStateBackend backend;
+	protected RocksDBKeyedStateBackend<K> backend;

 	/** The column family of this particular instance of state */
 	protected ColumnFamilyHandle columnFamily;
@@ -69,14 +71,20 @@ public abstract class AbstractRocksDBState<K, N, S extends State, SD extends Sta
 	 */
 	private final WriteOptions writeOptions;

+	protected final ByteArrayOutputStreamWithPos keySerializationStream;
+	protected final DataOutputView keySerializationDateDataOutputView;
+
+	private final boolean ambiguousKeyPossible;
+
 	/**
 	 * Creates a new RocksDB backed state.
 	 * @param namespaceSerializer The serializer for the namespace.
 	 */
-	protected AbstractRocksDBState(ColumnFamilyHandle columnFamily,
+	protected AbstractRocksDBState(
+			ColumnFamilyHandle columnFamily,
 			TypeSerializer<N> namespaceSerializer,
 			SD stateDesc,
-			RocksDBKeyedStateBackend backend) {
+			RocksDBKeyedStateBackend<K> backend) {
 		this.namespaceSerializer = namespaceSerializer;
 		this.backend = backend;
@@ -85,31 +93,27 @@ public abstract class AbstractRocksDBState<K, N, S extends State, SD extends Sta
 		writeOptions = new WriteOptions();
 		writeOptions.setDisableWAL(true);
 		this.stateDesc = Preconditions.checkNotNull(stateDesc, "State Descriptor");
+		this.keySerializationStream = new ByteArrayOutputStreamWithPos(128);
+		this.keySerializationDateDataOutputView = new DataOutputViewStreamWrapper(keySerializationStream);
+		this.ambiguousKeyPossible = (backend.getKeySerializer().getLength() < 0)
+				&& (namespaceSerializer.getLength() < 0);
 	}

 	// ------------------------------------------------------------------------

 	@Override
 	public void clear() {
-		ByteArrayOutputStream baos = new ByteArrayOutputStream();
-		DataOutputViewStreamWrapper out = new DataOutputViewStreamWrapper(baos);
 		try {
-			writeKeyAndNamespace(out);
-			byte[] key = baos.toByteArray();
+			writeCurrentKeyWithGroupAndNamespace();
+			byte[] key = keySerializationStream.toByteArray();
 			backend.db.remove(columnFamily, writeOptions, key);
 		} catch (IOException|RocksDBException e) {
 			throw new RuntimeException("Error while removing entry from RocksDB", e);
 		}
 	}

-	protected void writeKeyAndNamespace(DataOutputView out) throws IOException {
-		backend.getKeySerializer().serialize(backend.getCurrentKey(), out);
-		out.writeByte(42);
-		namespaceSerializer.serialize(currentNamespace, out);
-	}
-
 	@Override
 	public void setCurrentNamespace(N namespace) {
 		this.currentNamespace = Preconditions.checkNotNull(namespace, "Namespace");
@@ -118,17 +122,67 @@ public abstract class AbstractRocksDBState<K, N, S extends State, SD extends Sta
 	@Override
 	@SuppressWarnings("unchecked")
 	public byte[] getSerializedValue(byte[] serializedKeyAndNamespace) throws Exception {
-		// Serialized key and namespace is expected to be of the same format
-		// as writeKeyAndNamespace()
 		Preconditions.checkNotNull(serializedKeyAndNamespace, "Serialized key and namespace");
-		byte[] value = backend.db.get(columnFamily, serializedKeyAndNamespace);
-
-		if (value != null) {
-			return value;
-		} else {
-			return null;
-		}
+		//TODO make KvStateRequestSerializer key-group aware to save this round trip and key-group computation
+		Tuple2<K, N> des = KvStateRequestSerializer.<K, N>deserializeKeyAndNamespace(
+				serializedKeyAndNamespace,
+				backend.getKeySerializer(),
+				namespaceSerializer);
+
+		int keyGroup = backend.getKeyGroupAssigner().getKeyGroupIndex(des.f0);
+
+		writeKeyWithGroupAndNamespace(keyGroup, des.f0, des.f1);
+
+		return backend.db.get(columnFamily, keySerializationStream.toByteArray());
+	}
+
+	protected void writeCurrentKeyWithGroupAndNamespace() throws IOException {
+		writeKeyWithGroupAndNamespace(backend.getCurrentKeyGroupIndex(), backend.getCurrentKey(), currentNamespace);
+	}
+
+	protected void writeKeyWithGroupAndNamespace(int keyGroup, K key, N namespace) throws IOException {
+		keySerializationStream.reset();
+		writeKeyGroup(keyGroup);
+		writeKey(key);
+		writeNameSpace(namespace);
+	}
+
+	private void writeKeyGroup(int keyGroup) throws IOException {
+		for (int i = backend.getKeyGroupPrefixBytes(); --i >= 0;) {
+			keySerializationDateDataOutputView.writeByte(keyGroup >>> (i << 3));
+		}
+	}
+
+	private void writeKey(K key) throws IOException {
+		//write key
+		int beforeWrite = (int) keySerializationStream.getPosition();
+		backend.getKeySerializer().serialize(key, keySerializationDateDataOutputView);
+
+		if (ambiguousKeyPossible) {
+			//write size of key
+			writeLengthFrom(beforeWrite);
+		}
+	}
+
+	private void writeNameSpace(N namespace) throws IOException {
+		int beforeWrite = (int) keySerializationStream.getPosition();
+		namespaceSerializer.serialize(namespace, keySerializationDateDataOutputView);

+		if (ambiguousKeyPossible) {
+			//write length of namespace
+			writeLengthFrom(beforeWrite);
+		}
+	}
+
+	private void writeLengthFrom(int fromPosition) throws IOException {
+		int length = (int) (keySerializationStream.getPosition() - fromPosition);
+		writeVariableIntBytes(length);
+	}
+
+	private void writeVariableIntBytes(int value) throws IOException {
+		do {
+			keySerializationDateDataOutputView.writeByte(value);
+			value >>>= 8;
+		} while (value != 0);
+	}
 }
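Taken together, the new write methods replace the old "key, 0x2A separator byte, namespace" layout with a composite key of the form [key-group prefix | serialized key | serialized namespace], where length fields are appended only if both the key and the namespace serializer have variable length (the "ambiguous key" case). A standalone sketch of the prefix layout using plain java.io streams; the key group value, the two-byte prefix width, and the string key are invented for illustration:

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;

public class CompositeKeySketch {
	public static void main(String[] args) throws IOException {
		ByteArrayOutputStream bos = new ByteArrayOutputStream();
		DataOutputStream out = new DataOutputStream(bos);

		int keyGroup = 300;  // illustrative key group index
		int prefixBytes = 2; // assumed prefix width for maxParallelism > 255

		// write the key group big-endian over the prefix bytes, as writeKeyGroup() does
		for (int i = prefixBytes; --i >= 0;) {
			out.writeByte(keyGroup >>> (i << 3));
		}

		// serialized key follows; the namespace would come after it
		out.write("user-42".getBytes(StandardCharsets.UTF_8));

		byte[] composite = bos.toByteArray();
		// prints "01 2c": 0x012C == 300, so byte-wise ordering sorts by key group first
		System.out.printf("%02x %02x%n", composite[0], composite[1]);
	}
}

Putting the key group first means a plain lexicographic scan over RocksDB visits entries grouped by key group, which the snapshot code below relies on.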
@@ -29,7 +29,6 @@ import org.rocksdb.RocksDBException;
 import org.rocksdb.WriteOptions;

 import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
 import java.io.IOException;

 /**
@@ -66,7 +65,7 @@ public class RocksDBFoldingState<K, N, T, ACC>
 	public RocksDBFoldingState(ColumnFamilyHandle columnFamily,
 			TypeSerializer<N> namespaceSerializer,
 			FoldingStateDescriptor<T, ACC> stateDesc,
-			RocksDBKeyedStateBackend backend) {
+			RocksDBKeyedStateBackend<K> backend) {
 		super(columnFamily, namespaceSerializer, stateDesc, backend);
@@ -79,11 +78,9 @@ public class RocksDBFoldingState<K, N, T, ACC>
 	@Override
 	public ACC get() {
-		ByteArrayOutputStream baos = new ByteArrayOutputStream();
-		DataOutputViewStreamWrapper out = new DataOutputViewStreamWrapper(baos);
 		try {
-			writeKeyAndNamespace(out);
-			byte[] key = baos.toByteArray();
+			writeCurrentKeyWithGroupAndNamespace();
+			byte[] key = keySerializationStream.toByteArray();
 			byte[] valueBytes = backend.db.get(columnFamily, key);
 			if (valueBytes == null) {
 				return null;
@@ -96,23 +93,21 @@ public class RocksDBFoldingState<K, N, T, ACC>
 	@Override
 	public void add(T value) throws IOException {
-		ByteArrayOutputStream baos = new ByteArrayOutputStream();
-		DataOutputViewStreamWrapper out = new DataOutputViewStreamWrapper(baos);
 		try {
-			writeKeyAndNamespace(out);
-			byte[] key = baos.toByteArray();
+			writeCurrentKeyWithGroupAndNamespace();
+			byte[] key = keySerializationStream.toByteArray();
 			byte[] valueBytes = backend.db.get(columnFamily, key);
+			DataOutputViewStreamWrapper out = new DataOutputViewStreamWrapper(keySerializationStream);
 			if (valueBytes == null) {
-				baos.reset();
+				keySerializationStream.reset();
 				valueSerializer.serialize(foldFunction.fold(stateDesc.getDefaultValue(), value), out);
-				backend.db.put(columnFamily, writeOptions, key, baos.toByteArray());
+				backend.db.put(columnFamily, writeOptions, key, keySerializationStream.toByteArray());
 			} else {
 				ACC oldValue = valueSerializer.deserialize(new DataInputViewStreamWrapper(new ByteArrayInputStream(valueBytes)));
 				ACC newValue = foldFunction.fold(oldValue, value);
-				baos.reset();
+				keySerializationStream.reset();
 				valueSerializer.serialize(newValue, out);
-				backend.db.put(columnFamily, writeOptions, key, baos.toByteArray());
+				backend.db.put(columnFamily, writeOptions, key, keySerializationStream.toByteArray());
 			}
 		} catch (Exception e) {
 			throw new RuntimeException("Error while adding data to RocksDB", e);
......
@@ -28,7 +28,6 @@ import org.rocksdb.RocksDBException;
 import org.rocksdb.WriteOptions;

 import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
@@ -67,7 +66,7 @@ public class RocksDBListState<K, N, V>
 	public RocksDBListState(ColumnFamilyHandle columnFamily,
 			TypeSerializer<N> namespaceSerializer,
 			ListStateDescriptor<V> stateDesc,
-			RocksDBKeyedStateBackend backend) {
+			RocksDBKeyedStateBackend<K> backend) {
 		super(columnFamily, namespaceSerializer, stateDesc, backend);
 		this.valueSerializer = stateDesc.getSerializer();
@@ -78,11 +77,9 @@ public class RocksDBListState<K, N, V>
 	@Override
 	public Iterable<V> get() {
-		ByteArrayOutputStream baos = new ByteArrayOutputStream();
-		DataOutputViewStreamWrapper out = new DataOutputViewStreamWrapper(baos);
 		try {
-			writeKeyAndNamespace(out);
-			byte[] key = baos.toByteArray();
+			writeCurrentKeyWithGroupAndNamespace();
+			byte[] key = keySerializationStream.toByteArray();
 			byte[] valueBytes = backend.db.get(columnFamily, key);

 			if (valueBytes == null) {
@@ -107,16 +104,13 @@ public class RocksDBListState<K, N, V>
 	@Override
 	public void add(V value) throws IOException {
-		ByteArrayOutputStream baos = new ByteArrayOutputStream();
-		DataOutputViewStreamWrapper out = new DataOutputViewStreamWrapper(baos);
 		try {
-			writeKeyAndNamespace(out);
-			byte[] key = baos.toByteArray();
-
-			baos.reset();
+			writeCurrentKeyWithGroupAndNamespace();
+			byte[] key = keySerializationStream.toByteArray();
+			keySerializationStream.reset();
+			DataOutputViewStreamWrapper out = new DataOutputViewStreamWrapper(keySerializationStream);
 			valueSerializer.serialize(value, out);
-			backend.db.merge(columnFamily, writeOptions, key, baos.toByteArray());
+			backend.db.merge(columnFamily, writeOptions, key, keySerializationStream.toByteArray());
 		} catch (Exception e) {
 			throw new RuntimeException("Error while adding data to RocksDB", e);
......
@@ -29,7 +29,6 @@ import org.rocksdb.RocksDBException;
 import org.rocksdb.WriteOptions;

 import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
 import java.io.IOException;

 /**
@@ -65,7 +64,7 @@ public class RocksDBReducingState<K, N, V>
 	public RocksDBReducingState(ColumnFamilyHandle columnFamily,
 			TypeSerializer<N> namespaceSerializer,
 			ReducingStateDescriptor<V> stateDesc,
-			RocksDBKeyedStateBackend backend) {
+			RocksDBKeyedStateBackend<K> backend) {
 		super(columnFamily, namespaceSerializer, stateDesc, backend);
 		this.valueSerializer = stateDesc.getSerializer();
@@ -77,11 +76,9 @@ public class RocksDBReducingState<K, N, V>
 	@Override
 	public V get() {
-		ByteArrayOutputStream baos = new ByteArrayOutputStream();
-		DataOutputViewStreamWrapper out = new DataOutputViewStreamWrapper(baos);
 		try {
-			writeKeyAndNamespace(out);
-			byte[] key = baos.toByteArray();
+			writeCurrentKeyWithGroupAndNamespace();
+			byte[] key = keySerializationStream.toByteArray();
 			byte[] valueBytes = backend.db.get(columnFamily, key);
 			if (valueBytes == null) {
 				return null;
@@ -94,23 +91,22 @@ public class RocksDBReducingState<K, N, V>
 	@Override
 	public void add(V value) throws IOException {
-		ByteArrayOutputStream baos = new ByteArrayOutputStream();
-		DataOutputViewStreamWrapper out = new DataOutputViewStreamWrapper(baos);
 		try {
-			writeKeyAndNamespace(out);
-			byte[] key = baos.toByteArray();
+			writeCurrentKeyWithGroupAndNamespace();
+			byte[] key = keySerializationStream.toByteArray();
 			byte[] valueBytes = backend.db.get(columnFamily, key);
+			DataOutputViewStreamWrapper out = new DataOutputViewStreamWrapper(keySerializationStream);

 			if (valueBytes == null) {
-				baos.reset();
+				keySerializationStream.reset();
 				valueSerializer.serialize(value, out);
-				backend.db.put(columnFamily, writeOptions, key, baos.toByteArray());
+				backend.db.put(columnFamily, writeOptions, key, keySerializationStream.toByteArray());
 			} else {
 				V oldValue = valueSerializer.deserialize(new DataInputViewStreamWrapper(new ByteArrayInputStream(valueBytes)));
 				V newValue = reduceFunction.reduce(oldValue, value);
-				baos.reset();
+				keySerializationStream.reset();
 				valueSerializer.serialize(newValue, out);
-				backend.db.put(columnFamily, writeOptions, key, baos.toByteArray());
+				backend.db.put(columnFamily, writeOptions, key, keySerializationStream.toByteArray());
 			}
 		} catch (Exception e) {
 			throw new RuntimeException("Error while adding data to RocksDB", e);
......
@@ -221,12 +221,6 @@ public class RocksDBStateBackend extends AbstractStateBackend {
 	@Override
 	public CheckpointStreamFactory createStreamFactory(JobID jobId,
 			String operatorIdentifier) throws IOException {
-		return null;
-	}
-
-		if (fullyAsyncBackup) {
-			return performFullyAsyncSnapshot(checkpointId, timestamp);
-		} else {
 		return checkpointStreamBackend.createStreamFactory(jobId, operatorIdentifier);
 	}
@@ -261,10 +255,24 @@ public class RocksDBStateBackend extends AbstractStateBackend {
 			String operatorIdentifier,
 			TypeSerializer<K> keySerializer,
 			KeyGroupAssigner<K> keyGroupAssigner,
 			KeyGroupRange keyGroupRange,
 			List<KeyGroupsStateHandle> restoredState,
 			TaskKvStateRegistry kvStateRegistry) throws Exception {
-		throw new RuntimeException("Not implemented.");
+
+		lazyInitializeForJob(env, operatorIdentifier);
+
+		File instanceBasePath = new File(getDbPath(), UUID.randomUUID().toString());
+
+		return new RocksDBKeyedStateBackend<>(
+				jobID,
+				operatorIdentifier,
+				instanceBasePath,
+				getDbOptions(),
+				getColumnOptions(),
+				kvStateRegistry,
+				keySerializer,
+				keyGroupAssigner,
+				keyGroupRange,
+				restoredState);
 	}

 	// ------------------------------------------------------------------------
......
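On the user side the backend is configured as before; only the internals changed. A hedged usage sketch (the checkpoint URI and storage path are placeholders; the constructor and setter are the same ones exercised by the test code below):

import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

import java.net.URI;

public class RocksDBBackendUsageSketch {
	public static void main(String[] args) throws Exception {
		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

		// checkpoint streams go to the wrapped backend; RocksDB instances live under the storage path
		RocksDBStateBackend backend = new RocksDBStateBackend(
				new URI("file:///tmp/flink-checkpoints"), // placeholder checkpoint location
				new MemoryStateBackend());
		backend.setDbStoragePath("/tmp/flink-rocksdb");   // placeholder local DB directory

		env.setStateBackend(backend);
	}
}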
@@ -24,13 +24,11 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
-import org.apache.flink.util.Preconditions;
 import org.rocksdb.ColumnFamilyHandle;
 import org.rocksdb.RocksDBException;
 import org.rocksdb.WriteOptions;

 import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
 import java.io.IOException;

 /**
@@ -63,7 +61,7 @@ public class RocksDBValueState<K, N, V>
 	public RocksDBValueState(ColumnFamilyHandle columnFamily,
 			TypeSerializer<N> namespaceSerializer,
 			ValueStateDescriptor<V> stateDesc,
-			RocksDBKeyedStateBackend backend) {
+			RocksDBKeyedStateBackend<K> backend) {
 		super(columnFamily, namespaceSerializer, stateDesc, backend);
 		this.valueSerializer = stateDesc.getSerializer();
@@ -74,11 +72,9 @@ public class RocksDBValueState<K, N, V>
 	@Override
 	public V value() {
-		ByteArrayOutputStream baos = new ByteArrayOutputStream();
-		DataOutputViewStreamWrapper out = new DataOutputViewStreamWrapper(baos);
 		try {
-			writeKeyAndNamespace(out);
-			byte[] key = baos.toByteArray();
+			writeCurrentKeyWithGroupAndNamespace();
+			byte[] key = keySerializationStream.toByteArray();
 			byte[] valueBytes = backend.db.get(columnFamily, key);
 			if (valueBytes == null) {
 				return stateDesc.getDefaultValue();
@@ -95,14 +91,13 @@ public class RocksDBValueState<K, N, V>
 			clear();
 			return;
 		}
-		ByteArrayOutputStream baos = new ByteArrayOutputStream();
-		DataOutputViewStreamWrapper out = new DataOutputViewStreamWrapper(baos);
+		DataOutputViewStreamWrapper out = new DataOutputViewStreamWrapper(keySerializationStream);
 		try {
-			writeKeyAndNamespace(out);
-			byte[] key = baos.toByteArray();
-			baos.reset();
+			writeCurrentKeyWithGroupAndNamespace();
+			byte[] key = keySerializationStream.toByteArray();
+			keySerializationStream.reset();
 			valueSerializer.serialize(value, out);
-			backend.db.put(columnFamily, writeOptions, key, baos.toByteArray());
+			backend.db.put(columnFamily, writeOptions, key, keySerializationStream.toByteArray());
 		} catch (Exception e) {
 			throw new RuntimeException("Error while adding data to RocksDB", e);
 		}
@@ -110,11 +105,7 @@ public class RocksDBValueState<K, N, V>
 	@Override
 	public byte[] getSerializedValue(byte[] serializedKeyAndNamespace) throws Exception {
-		// Serialized key and namespace is expected to be of the same format
-		// as writeKeyAndNamespace()
-		Preconditions.checkNotNull(serializedKeyAndNamespace, "Serialized key and namespace");
-
-		byte[] value = backend.db.get(columnFamily, serializedKeyAndNamespace);
+		byte[] value = super.getSerializedValue(serializedKeyAndNamespace);

 		if (value != null) {
 			return value;
......
@@ -18,6 +18,7 @@
 package org.apache.flink.contrib.streaming.state;

+import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
@@ -28,17 +29,20 @@ import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
-import org.apache.flink.runtime.state.VoidNamespace;
-import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.VoidNamespace;
+import org.apache.flink.runtime.state.VoidNamespaceSerializer;
+import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.AsynchronousException;
 import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
 import org.apache.flink.streaming.runtime.tasks.OneInputStreamTaskTestHarness;
 import org.apache.flink.streaming.runtime.tasks.StreamMockEnvironment;
@@ -47,9 +51,9 @@ import org.apache.flink.util.OperatingSystem;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.LocalFileSystem;
+import org.junit.Assert;
 import org.junit.Assume;
 import org.junit.Before;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.powermock.api.mockito.PowerMockito;
@@ -58,13 +62,15 @@ import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;

 import java.io.File;
+import java.io.IOException;
 import java.lang.reflect.Field;
 import java.net.URI;
+import java.util.Arrays;
 import java.util.List;
 import java.util.UUID;
+import java.util.concurrent.CancellationException;

 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;

 /**
  * Tests for asynchronous RocksDB Key/Value state checkpoints.
@@ -73,7 +79,7 @@ import static org.junit.Assert.assertTrue;
 @PrepareForTest({ResultPartitionWriter.class, FileSystem.class})
 @PowerMockIgnore({"javax.management.*", "com.sun.jndi.*"})
 @SuppressWarnings("serial")
-public class RocksDBAsyncKVSnapshotTest {
+public class RocksDBAsyncSnapshotTest {

 	@Before
 	public void checkOperatingSystem() {
@@ -88,14 +94,12 @@ public class RocksDBAsyncKVSnapshotTest {
 	 * test will simply lock forever.
 	 */
 	@Test
-	public void testAsyncCheckpoints() throws Exception {
+	public void testFullyAsyncSnapshot() throws Exception {
 		LocalFileSystem localFS = new LocalFileSystem();
 		localFS.initialize(new URI("file:///"), new Configuration());
 		PowerMockito.stub(PowerMockito.method(FileSystem.class, "get", URI.class, Configuration.class)).toReturn(localFS);

-		final OneShotLatch delayCheckpointLatch = new OneShotLatch();
-		final OneShotLatch ensureCheckpointLatch = new OneShotLatch();
-
 		final OneInputStreamTask<String, String> task = new OneInputStreamTask<>();
 		final OneInputStreamTaskTestHarness<String, String> testHarness = new OneInputStreamTaskTestHarness<>(task, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO);
@@ -119,12 +123,15 @@ public class RocksDBAsyncKVSnapshotTest {

 		streamConfig.setStreamOperator(new AsyncCheckpointOperator());

+		final OneShotLatch delayCheckpointLatch = new OneShotLatch();
+		final OneShotLatch ensureCheckpointLatch = new OneShotLatch();
+
 		StreamMockEnvironment mockEnv = new StreamMockEnvironment(
 				testHarness.jobConfig,
 				testHarness.taskConfig,
 				testHarness.memorySize,
 				new MockInputSplitProvider(),
 				testHarness.bufferSize) {

 			@Override
 			public void acknowledgeCheckpoint(long checkpointId) {
@@ -133,8 +140,8 @@ public class RocksDBAsyncKVSnapshotTest {

 			@Override
 			public void acknowledgeCheckpoint(long checkpointId,
 					ChainedStateHandle<StreamStateHandle> chainedStateHandle,
 					List<KeyGroupsStateHandle> keyGroupStateHandles) {
 				super.acknowledgeCheckpoint(checkpointId, chainedStateHandle, keyGroupStateHandles);

 				// block on the latch, to verify that triggerCheckpoint returns below,
@@ -145,9 +152,7 @@ public class RocksDBAsyncKVSnapshotTest {
 					e.printStackTrace();
 				}

 				// should be only one k/v state
 				assertEquals(1, keyGroupStateHandles.size());

 				// we now know that the checkpoint went through
@@ -183,23 +188,20 @@ public class RocksDBAsyncKVSnapshotTest {
 	}

 	/**
-	 * This ensures that asynchronous state handles are actually materialized asynchonously.
-	 *
-	 * <p>We use latches to block at various stages and see if the code still continues through
-	 * the parts that are not asynchronous. If the checkpoint is not done asynchronously the
-	 * test will simply lock forever.
+	 * This test ensures that canceling of asynchronous snapshots works as expected and does not block.
+	 * @throws Exception
 	 */
 	@Test
-	public void testFullyAsyncCheckpoints() throws Exception {
+	public void testCancelFullyAsyncCheckpoints() throws Exception {
 		LocalFileSystem localFS = new LocalFileSystem();
 		localFS.initialize(new URI("file:///"), new Configuration());
 		PowerMockito.stub(PowerMockito.method(FileSystem.class, "get", URI.class, Configuration.class)).toReturn(localFS);

-		final OneShotLatch delayCheckpointLatch = new OneShotLatch();
-		final OneShotLatch ensureCheckpointLatch = new OneShotLatch();
-
 		final OneInputStreamTask<String, String> task = new OneInputStreamTask<>();
+
+		//ensure that the async threads complete before invoke method of the tasks returns.
+		task.setThreadPoolTerminationTimeout(Long.MAX_VALUE);
+
 		final OneInputStreamTaskTestHarness<String, String> testHarness = new OneInputStreamTaskTestHarness<>(task, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO);

 		testHarness.configureForKeyedStream(new KeySelector<String, String>() {
@@ -214,9 +216,10 @@ public class RocksDBAsyncKVSnapshotTest {
 		File dbDir = new File(new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()), "state");
 		File chkDir = new File(new File(ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH, UUID.randomUUID().toString()), "snapshots");

-		RocksDBStateBackend backend = new RocksDBStateBackend(chkDir.getAbsoluteFile().toURI(), new MemoryStateBackend());
+		BlockingStreamMemoryStateBackend memoryStateBackend = new BlockingStreamMemoryStateBackend();
+		RocksDBStateBackend backend = new RocksDBStateBackend(chkDir.getAbsoluteFile().toURI(), memoryStateBackend);
 		backend.setDbStoragePath(dbDir.getAbsolutePath());
-		// backend.enableFullyAsyncSnapshots();

 		streamConfig.setStateBackend(backend);

@@ -227,34 +230,7 @@ public class RocksDBAsyncKVSnapshotTest {
 				testHarness.taskConfig,
 				testHarness.memorySize,
 				new MockInputSplitProvider(),
-				testHarness.bufferSize) {
-
-			@Override
-			public void acknowledgeCheckpoint(long checkpointId) {
-				super.acknowledgeCheckpoint(checkpointId);
-			}
-
-			@Override
-			public void acknowledgeCheckpoint(long checkpointId,
-					ChainedStateHandle<StreamStateHandle> chainedStateHandle,
-					List<KeyGroupsStateHandle> keyGroupStateHandles) {
-				super.acknowledgeCheckpoint(checkpointId, chainedStateHandle, keyGroupStateHandles);
-
-				// block on the latch, to verify that triggerCheckpoint returns below,
-				// even though the async checkpoint would not finish
-				try {
-					delayCheckpointLatch.await();
-				} catch (InterruptedException e) {
-					e.printStackTrace();
-				}
-
-				// should be only one k/v state
-				assertEquals(1, keyGroupStateHandles.size());
-
-				// we now know that the checkpoint went through
-				ensureCheckpointLatch.trigger();
-			}
-		};
+				testHarness.bufferSize);

 		testHarness.invoke(mockEnv);

@@ -273,19 +249,110 @@ public class RocksDBAsyncKVSnapshotTest {

 		task.triggerCheckpoint(42, 17);

-		// now we allow the checkpoint
-		delayCheckpointLatch.trigger();
-
-		// wait for the checkpoint to go through
-		ensureCheckpointLatch.await();
+		BlockingStreamMemoryStateBackend.waitFirstWriteLatch.await();
+		task.cancel();
+		BlockingStreamMemoryStateBackend.unblockCancelLatch.trigger();

 		testHarness.endInput();
-		testHarness.waitForTaskCompletion();
+
+		try {
+			testHarness.waitForTaskCompletion();
+			Assert.fail("Operation completed. Cancel failed.");
+		} catch (Exception expected) {
+			// we expect the exception from canceling snapshots
+			Throwable cause = expected.getCause();
+			if (cause instanceof AsynchronousException) {
+				AsynchronousException asynchronousException = (AsynchronousException) cause;
+				cause = asynchronousException.getCause();
+				Assert.assertTrue("Unexpected Exception: " + cause,
+						cause instanceof CancellationException //future canceled
+								|| cause instanceof InterruptedException); //thread interrupted
+			} else {
+				Assert.fail();
+			}
+		}
 	}

+	@Test
+	public void testConsistentSnapshotSerializationFlagsAndMasks() {
+
+		Assert.assertEquals(0xFFFF, RocksDBKeyedStateBackend.RocksDBSnapshotOperation.END_OF_KEY_GROUP_MARK);
+		Assert.assertEquals(0x80, RocksDBKeyedStateBackend.RocksDBSnapshotOperation.FIRST_BIT_IN_BYTE_MASK);
+
+		byte[] expectedKey = new byte[] {42, 42};
+		byte[] modKey = expectedKey.clone();
+
+		Assert.assertFalse(RocksDBKeyedStateBackend.RocksDBSnapshotOperation.hasMetaDataFollowsFlag(modKey));
+
+		RocksDBKeyedStateBackend.RocksDBSnapshotOperation.setMetaDataFollowsFlagInKey(modKey);
+		Assert.assertTrue(RocksDBKeyedStateBackend.RocksDBSnapshotOperation.hasMetaDataFollowsFlag(modKey));
+
+		RocksDBKeyedStateBackend.RocksDBSnapshotOperation.clearMetaDataFollowsFlag(modKey);
+		Assert.assertFalse(RocksDBKeyedStateBackend.RocksDBSnapshotOperation.hasMetaDataFollowsFlag(modKey));
+
+		Assert.assertTrue(Arrays.equals(expectedKey, modKey));
+	}
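The constants verified by this test encode the snapshot framing: the high bit of a key's first byte flags that metadata follows, and 0xFFFF marks the end of a key group. The same bit operations in isolation (illustrative only, not the backend's actual helper methods):

public class MetaDataFlagSketch {
	static final int FIRST_BIT_IN_BYTE_MASK = 0x80;

	public static void main(String[] args) {
		byte[] key = {42, 42};

		// set the "metadata follows" flag in the first byte of the key
		key[0] |= FIRST_BIT_IN_BYTE_MASK;
		System.out.println((key[0] & FIRST_BIT_IN_BYTE_MASK) != 0); // true

		// clear the flag again; the original key bytes are restored
		key[0] &= ~FIRST_BIT_IN_BYTE_MASK;
		System.out.println(key[0] == 42); // true
	}
}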
 	// ------------------------------------------------------------------------

+	/**
+	 * Creates a CheckpointStateOutputStream that blocks write ops on a latch to delay writing of snapshots.
+	 */
+	static class BlockingStreamMemoryStateBackend extends MemoryStateBackend {
+
+		public static OneShotLatch waitFirstWriteLatch = new OneShotLatch();
+
+		public static OneShotLatch unblockCancelLatch = new OneShotLatch();
+
+		volatile boolean closed = false;
+
+		@Override
+		public CheckpointStreamFactory createStreamFactory(JobID jobId, String operatorIdentifier) throws IOException {
+			return new MemCheckpointStreamFactory(4 * 1024 * 1024) {
+				@Override
+				public CheckpointStateOutputStream createCheckpointStateOutputStream(long checkpointID, long timestamp) throws Exception {
+					return new MemoryCheckpointOutputStream(4 * 1024 * 1024) {
+
+						@Override
+						public void write(int b) throws IOException {
+							waitFirstWriteLatch.trigger();
+							try {
+								unblockCancelLatch.await();
+							} catch (InterruptedException e) {
+								Thread.currentThread().interrupt();
+							}
+							if (closed) {
+								throw new IOException("Stream closed.");
+							}
+							super.write(b);
+						}
+
+						@Override
+						public void write(byte[] b, int off, int len) throws IOException {
+							waitFirstWriteLatch.trigger();
+							try {
+								unblockCancelLatch.await();
+							} catch (InterruptedException e) {
+								Thread.currentThread().interrupt();
+							}
+							if (closed) {
+								throw new IOException("Stream closed.");
+							}
+							super.write(b, off, len);
+						}
+
+						@Override
+						public void close() {
+							closed = true;
+							super.close();
+						}
+					};
+				}
+			};
+		}
+	}
+
 	public static class AsyncCheckpointOperator
 		extends AbstractStreamOperator<String>
 		implements OneInputStreamOperator<String, String> {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.contrib.streaming.state;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
import org.apache.flink.runtime.testutils.CommonTestUtils;
import org.junit.Assert;
import org.junit.Test;
import org.rocksdb.ColumnFamilyDescriptor;
import org.rocksdb.ColumnFamilyHandle;
import org.rocksdb.RocksDB;
import org.rocksdb.RocksIterator;
import java.io.DataOutputStream;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
public class RocksDBMergeIteratorTest {
private static final int NUM_KEY_VAL_STATES = 50;
private static final int MAX_NUM_KEYS = 20;
@Test
public void testEmptyMergeIterator() throws IOException {
RocksDBKeyedStateBackend.RocksDBMergeIterator emptyIterator =
new RocksDBKeyedStateBackend.RocksDBMergeIterator(Collections.EMPTY_LIST, 2);
Assert.assertFalse(emptyIterator.isValid());
}
@Test
public void testMergeIterator() throws Exception {
Assert.assertTrue(MAX_NUM_KEYS <= Byte.MAX_VALUE);
testMergeIterator(Byte.MAX_VALUE);
testMergeIterator(Short.MAX_VALUE);
}
public void testMergeIterator(int maxParallelism) throws Exception {
Random random = new Random(1234);
File tmpDir = CommonTestUtils.createTempDirectory();
RocksDB rocksDB = RocksDB.open(tmpDir.getAbsolutePath());
try {
List<Tuple2<RocksIterator, Integer>> rocksIteratorsWithKVStateId = new ArrayList<>();
List<Tuple2<ColumnFamilyHandle, Integer>> columnFamilyHandlesWithKeyCount = new ArrayList<>();
int totalKeysExpected = 0;
for (int c = 0; c < NUM_KEY_VAL_STATES; ++c) {
ColumnFamilyHandle handle = rocksDB.createColumnFamily(
new ColumnFamilyDescriptor(("column-" + c).getBytes()));
ByteArrayOutputStreamWithPos bos = new ByteArrayOutputStreamWithPos();
DataOutputStream dos = new DataOutputStream(bos);
int numKeys = random.nextInt(MAX_NUM_KEYS + 1);
for (int i = 0; i < numKeys; ++i) {
if (maxParallelism <= Byte.MAX_VALUE) {
dos.writeByte(i);
} else {
dos.writeShort(i);
}
dos.writeInt(i);
byte[] key = bos.toByteArray();
byte[] val = new byte[]{42};
rocksDB.put(handle, key, val);
bos.reset();
}
columnFamilyHandlesWithKeyCount.add(new Tuple2<>(handle, numKeys));
totalKeysExpected += numKeys;
}
int id = 0;
for (Tuple2<ColumnFamilyHandle, Integer> columnFamilyHandle : columnFamilyHandlesWithKeyCount) {
rocksIteratorsWithKVStateId.add(new Tuple2<>(rocksDB.newIterator(columnFamilyHandle.f0), id));
++id;
}
RocksDBKeyedStateBackend.RocksDBMergeIterator mergeIterator = new RocksDBKeyedStateBackend.RocksDBMergeIterator(rocksIteratorsWithKVStateId, maxParallelism <= Byte.MAX_VALUE ? 1 : 2);
int prevKVState = -1;
int prevKey = -1;
int prevKeyGroup = -1;
int totalKeysActual = 0;
while (mergeIterator.isValid()) {
ByteBuffer bb = ByteBuffer.wrap(mergeIterator.key());
int keyGroup = maxParallelism > Byte.MAX_VALUE ? bb.getShort() : bb.get();
int key = bb.getInt();
Assert.assertTrue(keyGroup >= prevKeyGroup);
Assert.assertTrue(key >= prevKey);
Assert.assertEquals(prevKeyGroup != keyGroup, mergeIterator.isNewKeyGroup());
Assert.assertEquals(prevKVState != mergeIterator.kvStateId(), mergeIterator.isNewKeyValueState());
prevKeyGroup = keyGroup;
prevKVState = mergeIterator.kvStateId();
//System.out.println(keyGroup + " " + key + " " + mergeIterator.kvStateId());
mergeIterator.next();
++totalKeysActual;
}
Assert.assertEquals(totalKeysExpected, totalKeysActual);
for (Tuple2<ColumnFamilyHandle, Integer> handleWithCount : columnFamilyHandlesWithKeyCount) {
rocksDB.dropColumnFamily(handleWithCount.f0);
}
} finally {
rocksDB.close();
}
}
}
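Conceptually the iterator under test performs a k-way merge. A rough model (not Flink's actual RocksDBMergeIterator, which compares raw RocksIterator key bytes): a priority queue ordered by (key group, key) interleaves several sorted per-state streams, so the merged output can be consumed key group by key group:

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;

public class MergeIteratorSketch {

	static final class Head {
		final Iterator<int[]> source; // each element: {keyGroup, key}
		final int stateId;
		int[] current;

		Head(Iterator<int[]> source, int stateId) {
			this.source = source;
			this.stateId = stateId;
			this.current = source.next(); // sketch assumes non-empty sources
		}
	}

	public static void main(String[] args) {
		List<int[]> stateA = Arrays.asList(new int[]{0, 1}, new int[]{1, 7});
		List<int[]> stateB = Arrays.asList(new int[]{0, 3}, new int[]{2, 5});

		PriorityQueue<Head> heap = new PriorityQueue<>((a, b) -> {
			int cmp = Integer.compare(a.current[0], b.current[0]); // key group first
			return cmp != 0 ? cmp : Integer.compare(a.current[1], b.current[1]);
		});
		heap.add(new Head(stateA.iterator(), 0));
		heap.add(new Head(stateB.iterator(), 1));

		int prevKeyGroup = -1;
		while (!heap.isEmpty()) {
			Head head = heap.poll();
			// analogue of isNewKeyGroup(): the merged stream changes key group here
			boolean newKeyGroup = head.current[0] != prevKeyGroup;
			System.out.printf("keyGroup=%d key=%d state=%d newKeyGroup=%b%n",
					head.current[0], head.current[1], head.stateId, newKeyGroup);
			prevKeyGroup = head.current[0];
			if (head.source.hasNext()) {
				head.current = head.source.next();
				heap.add(head);
			}
		}
	}
}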
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.io.async;
import java.io.Closeable;
import java.io.IOException;
/**
* The abstract class encapsulates the lifecycle and execution strategy for asynchronous IO operations
*
* @param <V> return type of the asynchronous call
* @param <D> type of the IO handle
*/
public abstract class AbstractAsyncIOCallable<V, D extends Closeable> implements StoppableCallbackCallable<V> {
private volatile boolean stopped;
/**
* Closable handle to IO, e.g. an InputStream
*/
private volatile D ioHandle;
/**
* Stores exception that might happen during close
*/
private volatile IOException stopException;
public AbstractAsyncIOCallable() {
this.stopped = false;
}
/**
* This method implements the strategy for the actual IO operation:
*
* 1) Open the IO handle
* 2) Perform IO operation
* 3) Close IO handle
*
* @return Result of the IO operation, e.g. a deserialized object.
* @throws Exception exception that happened during the call.
*/
@Override
public V call() throws Exception {
synchronized (this) {
if (isStopped()) {
throw new IOException("Task was already stopped. No I/O handle opened.");
}
ioHandle = openIOHandle();
}
try {
return performOperation();
} finally {
closeIOHandle();
}
}
/**
* Open the IO Handle (e.g. a stream) on which the operation will be performed.
*
* @return the opened IO handle that implements #Closeable
* @throws Exception
*/
protected abstract D openIOHandle() throws Exception;
/**
* Implements the actual IO operation on the opened IO handle.
*
* @return Result of the IO operation
* @throws Exception
*/
protected abstract V performOperation() throws Exception;
/**
* Stops the I/O operation by closing the I/O handle. If an exception is thrown on close, it can be accessed via
* #getStopException().
*/
@Override
public void stop() {
closeIOHandle();
}
private synchronized void closeIOHandle() {
if (!stopped) {
stopped = true;
final D handle = ioHandle;
if (handle != null) {
try {
handle.close();
} catch (IOException ex) {
stopException = ex;
}
}
}
}
/**
* Returns the IO handle.
* @return the IO handle
*/
protected D getIoHandle() {
return ioHandle;
}
/**
* Optional callback that subclasses can implement. This is called when the callable method completed, e.g. because
* it finished or was stopped.
*/
@Override
public void done() {
//optional callback hook
}
/**
* Check if the IO operation is stopped
*
* @return true if stop() was called
*/
@Override
public boolean isStopped() {
return stopped;
}
/**
* Returns Exception that might happen on stop.
*
 * @return Potential Exception that happened upon stopping.
*/
@Override
public IOException getStopException() {
return stopException;
}
}
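As a usage illustration of the template above, a hypothetical subclass could stream a file and remain stoppable; openIOHandle() and performOperation() are the real extension points, while the class name, file handling, and byte counting are invented for the example:

import org.apache.flink.runtime.io.async.AbstractAsyncIOCallable;

import java.io.FileInputStream;
import java.io.IOException;

public class FileSizeCallable extends AbstractAsyncIOCallable<Long, FileInputStream> {

	private final String path;

	public FileSizeCallable(String path) {
		this.path = path;
	}

	@Override
	protected FileInputStream openIOHandle() throws Exception {
		return new FileInputStream(path);
	}

	@Override
	protected Long performOperation() throws Exception {
		long count = 0;
		FileInputStream in = getIoHandle();
		byte[] buffer = new byte[4096];
		int read;
		while ((read = in.read(buffer)) != -1) {
			// stop() closes the handle concurrently; reads then fail, or we bail out here
			if (isStopped()) {
				throw new IOException("Operation was stopped.");
			}
			count += read;
		}
		return count;
	}
}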
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.io.async;
/**
* Callback for an asynchronous operation that is called on termination
*/
public interface AsyncDoneCallback {
/**
* the callback
*/
void done();
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.io.async;
import java.io.IOException;
/**
* An asynchronous operation that can be stopped.
*/
public interface AsyncStoppable {
/**
* Stop the operation
*/
void stop();
/**
* Check whether the operation is stopped
*
* @return true iff operation is stopped
*/
boolean isStopped();
/**
* Delivers Exception that might happen during {@link #stop()}
*
* @return Exception that can happen during stop
*/
IOException getStopException();
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.io.async;
import org.apache.flink.util.Preconditions;
import java.util.concurrent.FutureTask;
/**
* @param <V> return type of the callable function
*/
public class AsyncStoppableTaskWithCallback<V> extends FutureTask<V> {
protected final StoppableCallbackCallable<V> stoppableCallbackCallable;
public AsyncStoppableTaskWithCallback(StoppableCallbackCallable<V> callable) {
super(Preconditions.checkNotNull(callable));
this.stoppableCallbackCallable = callable;
}
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
if (mayInterruptIfRunning) {
stoppableCallbackCallable.stop();
}
return super.cancel(mayInterruptIfRunning);
}
@Override
protected void done() {
stoppableCallbackCallable.done();
}
public static <V> AsyncStoppableTaskWithCallback<V> from(StoppableCallbackCallable<V> callable) {
return new AsyncStoppableTaskWithCallback<>(callable);
}
}
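A sketch of driving such a task from an executor (the executor setup and the FileSizeCallable from the earlier sketch are assumptions). The point of the override is that cancel(true) first routes into stop(), closing the I/O handle, before the usual FutureTask interruption:

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class AsyncStoppableTaskExample {
	public static void main(String[] args) {
		ExecutorService executor = Executors.newSingleThreadExecutor();

		AsyncStoppableTaskWithCallback<Long> task =
				AsyncStoppableTaskWithCallback.from(new FileSizeCallable("/tmp/some-file"));

		executor.execute(task);

		// stop() is invoked on the callable, then the thread is interrupted;
		// done() fires afterwards as the completion callback
		task.cancel(true);

		executor.shutdown();
	}
}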
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.io.async;
import java.util.concurrent.Callable;
/**
* A {@link Callable} that can be stopped and offers a callback on termination.
*
* @param <V> return value of the call operation.
*/
public interface StoppableCallbackCallable<V> extends Callable<V>, AsyncStoppable, AsyncDoneCallback {
}
@@ -71,7 +71,7 @@ public class MemCheckpointStreamFactory implements CheckpointStreamFactory {
 	/**
 	 * A {@code CheckpointStateOutputStream} that writes into a byte array.
 	 */
-	public static final class MemoryCheckpointOutputStream extends CheckpointStateOutputStream {
+	public static class MemoryCheckpointOutputStream extends CheckpointStateOutputStream {

 		private final ByteArrayOutputStreamWithPos os = new ByteArrayOutputStreamWithPos();
@@ -86,13 +86,13 @@ public class MemCheckpointStreamFactory implements CheckpointStreamFactory {
 		}

 		@Override
-		public void write(int b) {
+		public void write(int b) throws IOException {
 			os.write(b);
 			isEmpty = false;
 		}

 		@Override
-		public void write(byte[] b, int off, int len) {
+		public void write(byte[] b, int off, int len) throws IOException {
 			os.write(b, off, len);
 			isEmpty = false;
 		}
......
@@ -66,6 +66,7 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.RunnableFuture;
 import java.util.concurrent.ScheduledFuture;
 import java.util.concurrent.ScheduledThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;

 /**
  * Base class for all streaming tasks. A task is the unit of local processing that is deployed
@@ -176,8 +177,12 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>

 	private long lastCheckpointSize = 0;

+	/** Thread pool for async snapshot workers */
 	private ExecutorService asyncOperationsThreadPool;

+	/** Timeout to await the termination of the thread pool in milliseconds */
+	private long threadPoolTerminationTimeout = 0L;
+
 	// ------------------------------------------------------------------------
 	//  Life cycle methods for specific implementations
 	// ------------------------------------------------------------------------
@@ -441,6 +446,10 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 			if (!asyncOperationsThreadPool.isShutdown()) {
 				asyncOperationsThreadPool.shutdownNow();
 			}
+
+			if (threadPoolTerminationTimeout > 0L) {
+				asyncOperationsThreadPool.awaitTermination(threadPoolTerminationTimeout, TimeUnit.MILLISECONDS);
+			}
 		}

 	/**
@@ -861,6 +870,15 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 		};
 	}

+	/**
+	 * Sets a timeout for the async thread pool. Default should always be 0 to avoid blocking restarts of task.
+	 *
+	 * @param threadPoolTerminationTimeout timeout for the async thread pool in milliseconds
+	 */
+	public void setThreadPoolTerminationTimeout(long threadPoolTerminationTimeout) {
+		this.threadPoolTerminationTimeout = threadPoolTerminationTimeout;
+	}
+
 	// ------------------------------------------------------------------------

 	/**
......
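As the cancellation test above shows, the new hook is intended for tests that must observe the outcome of async snapshot threads before invoke() returns; production code keeps the default of 0 so task restarts are never blocked. A recap of the intended call (the timeout value is illustrative):

OneInputStreamTask<String, String> task = new OneInputStreamTask<>();
// block invoke() until async snapshot threads have terminated
task.setThreadPoolTerminationTimeout(Long.MAX_VALUE);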