提交 108eba3a 编写于 作者: G gyfora 提交者: Stephan Ewen

[streaming] connectWith fix

上级 543a27de
......@@ -74,6 +74,11 @@ public class DataStream<T extends Tuple> {
}
public DataStream<T> batch(int batchSize) {
if (batchSize < 1) {
throw new IllegalArgumentException("Batch size must be positive.");
}
for (int i = 0; i < batchSizes.size(); i++) {
batchSizes.set(i, batchSize);
}
......@@ -85,6 +90,7 @@ public class DataStream<T extends Tuple> {
connectIDs.addAll(stream.connectIDs);
ctypes.addAll(stream.ctypes);
cparams.addAll(stream.cparams);
batchSizes.addAll(stream.batchSizes);
return this;
}
......
......@@ -64,6 +64,8 @@ public class JobGraphBuilder {
protected String maxParallelismVertexName;
protected int maxParallelism;
protected FaultToleranceType faultToleranceType;
private int batchSize;
private long batchTimeout;
/**
* Creates a new JobGraph with the given name
......@@ -97,8 +99,11 @@ public class JobGraphBuilder {
this(jobGraphName, FaultToleranceType.NONE);
}
public JobGraphBuilder(String jobGraphName, FaultToleranceType faultToleranceType, int batchSize) {
public JobGraphBuilder(String jobGraphName, FaultToleranceType faultToleranceType,
int defaultBatchSize, long defaultBatchTimeoutMillis) {
this(jobGraphName, faultToleranceType);
this.batchSize = defaultBatchSize;
this.batchTimeout = defaultBatchTimeoutMillis;
}
/**
......@@ -244,7 +249,6 @@ public class JobGraphBuilder {
* @param component
* AbstractJobVertex associated with the component
*/
private Configuration setComponent(String componentName,
final Class<? extends StreamComponent> InvokableClass, int parallelism,
int subtasksPerInstance, AbstractJobVertex component) {
......@@ -259,6 +263,8 @@ public class JobGraphBuilder {
Configuration config = new TaskConfig(component.getConfiguration()).getConfiguration();
config.setClass("userfunction", InvokableClass);
config.setString("componentName", componentName);
config.setInteger("batchSize", batchSize);
config.setLong("batchTimeout", batchTimeout);
// config.setBytes("operator", getSerializedFunction());
config.setInteger("faultToleranceType", faultToleranceType.id);
......@@ -268,12 +274,6 @@ public class JobGraphBuilder {
return config;
}
public void setBatchSize(String componentName, int batchSize) {
AbstractJobVertex component = components.get(componentName);
Configuration config = component.getConfiguration();
config.setInteger("batchSize", batchSize);
}
private Configuration setComponent(String componentName,
UserSourceInvokable<? extends Tuple> InvokableObject, int parallelism,
int subtasksPerInstance, AbstractJobVertex component) {
......@@ -304,6 +304,11 @@ public class JobGraphBuilder {
return config;
}
public void setBatchSize(String componentName, int batchSize) {
Configuration config = components.get(componentName).getConfiguration();
config.setInteger("batchSize", batchSize);
}
/**
* Adds serialized invokable object to the JobVertex configuration
*
......
......@@ -28,37 +28,54 @@ public class StreamCollector<T extends Tuple> implements Collector<T> {
protected StreamRecord streamRecord;
protected int batchSize;
protected long batchTimeout;
protected int counter = 0;
protected int channelID;
private long timeOfLastRecordEmitted = System.currentTimeMillis();;
private List<RecordWriter<StreamRecord>> outputs;
public StreamCollector(int batchSize, int channelID,
public StreamCollector(int batchSize, long batchTimeout, int channelID,
SerializationDelegate<Tuple> serializationDelegate,
List<RecordWriter<StreamRecord>> outputs) {
this.batchSize = batchSize;
this.batchTimeout = batchTimeout;
this.streamRecord = new ArrayStreamRecord(batchSize);
this.streamRecord.setSeralizationDelegate(serializationDelegate);
this.channelID = channelID;
this.outputs = outputs;
}
public StreamCollector(int batchSize, int channelID,
public StreamCollector(int batchSize, long batchTimeout, int channelID,
SerializationDelegate<Tuple> serializationDelegate) {
this(batchSize, channelID, serializationDelegate, null);
this(batchSize, batchTimeout, channelID, serializationDelegate, null);
}
// TODO reconsider emitting mechanism at timeout (find a place to timeout)
@Override
public void collect(T tuple) {
streamRecord.setTuple(counter, StreamRecord.copyTuple(tuple));
counter++;
if (counter >= batchSize) {
counter = 0;
streamRecord.setId(channelID);
emit(streamRecord);
// timeOfLastRecordEmitted = System.currentTimeMillis();
} else {
// timeout();
}
}
public void timeout() {
if (timeOfLastRecordEmitted + batchTimeout < System.currentTimeMillis()) {
StreamRecord truncatedRecord = new ArrayStreamRecord(streamRecord, counter);
emit(truncatedRecord);
timeOfLastRecordEmitted = System.currentTimeMillis();
}
}
private void emit(StreamRecord streamRecord) {
counter = 0;
streamRecord.setId(channelID);
if (outputs == null) {
System.out.println(streamRecord);
} else {
......
......@@ -33,9 +33,19 @@ import eu.stratosphere.util.Collector;
public class StreamExecutionEnvironment {
JobGraphBuilder jobGraphBuilder;
public StreamExecutionEnvironment() {
jobGraphBuilder = new JobGraphBuilder("jobGraph", FaultToleranceType.NONE);
public StreamExecutionEnvironment(int defaultBatchSize, long defaultBatchTimeoutMillis) {
if (defaultBatchSize < 1) {
throw new IllegalArgumentException("Batch size must be positive.");
}
if (defaultBatchTimeoutMillis < 1) {
throw new IllegalArgumentException("Batch timeout must be positive.");
}
jobGraphBuilder = new JobGraphBuilder("jobGraph", FaultToleranceType.NONE,
defaultBatchSize, defaultBatchTimeoutMillis);
}
public StreamExecutionEnvironment() {
this(1, 1000);
}
private static class DummySource extends UserSourceInvokable<Tuple1<String>> {
......@@ -54,6 +64,7 @@ public class StreamExecutionEnvironment {
}
public <T extends Tuple> void setBatchSize(DataStream<T> inputStream) {
for (int i = 0; i < inputStream.connectIDs.size(); i++) {
jobGraphBuilder.setBatchSize(inputStream.connectIDs.get(i),
inputStream.batchSizes.get(i));
......
......@@ -107,8 +107,11 @@ public final class StreamComponentHelper<T extends AbstractInvokable> {
public StreamCollector<Tuple> setCollector(Configuration taskConfiguration, int id,
List<RecordWriter<StreamRecord>> outputs) {
int batchSize = taskConfiguration.getInteger("batchSize", 1);
collector = new StreamCollector<Tuple>(batchSize, id, outSerializationDelegate, outputs);
long batchTimeout = taskConfiguration.getLong("batchTimeout", 1000);
collector = new StreamCollector<Tuple>(batchSize, batchTimeout, id,
outSerializationDelegate, outputs);
return collector;
}
......
......@@ -45,14 +45,19 @@ public class ArrayStreamRecord extends StreamRecord {
}
public ArrayStreamRecord(StreamRecord record) {
tupleBatch = new Tuple[record.getBatchSize()];
this(record, record.getBatchSize());
}
public ArrayStreamRecord(StreamRecord record, int truncatedSize) {
tupleBatch = new Tuple[truncatedSize];
this.uid = new UID(Arrays.copyOf(record.getId().getId(), 20));
for (int i = 0; i < record.getBatchSize(); ++i) {
for (int i = 0; i < truncatedSize; ++i) {
this.tupleBatch[i] = copyTuple(record.getTuple(i));
}
this.batchSize = tupleBatch.length;
}
/**
* Creates a new batch of records containing the given Tuple array as
* elements
......
......@@ -15,13 +15,10 @@
package eu.stratosphere.streaming.api;
import static org.junit.Assert.fail;
import java.util.Iterator;
import org.junit.Test;
import eu.stratosphere.api.java.functions.FlatMapFunction;
import eu.stratosphere.api.java.functions.GroupReduceFunction;
import eu.stratosphere.api.java.tuple.Tuple1;
import eu.stratosphere.util.Collector;
......@@ -72,8 +69,8 @@ public class BatchReduceTest {
@Test
public void test() throws Exception {
StreamExecutionEnvironment context = new StreamExecutionEnvironment();
DataStream<Tuple1<Double>> dataStream0 = context.addSource(new MySource()).batch(4)
StreamExecutionEnvironment context = new StreamExecutionEnvironment(4, 1000);
DataStream<Tuple1<Double>> dataStream0 = context.addSource(new MySource())
.batchReduce(new MyBatchReduce()).addSink(new MySink());
context.execute();
......
package eu.stratosphere.streaming.api;
import static org.junit.Assert.assertEquals;
import org.junit.Test;
import eu.stratosphere.api.java.functions.FlatMapFunction;
import eu.stratosphere.api.java.tuple.Tuple1;
import eu.stratosphere.util.Collector;
public class BatchTest {
private static int count = 0;
private static final class MySource extends SourceFunction<Tuple1<String>> {
private Tuple1<String> outTuple = new Tuple1<String>();
@Override
public void invoke(Collector<Tuple1<String>> collector) throws Exception {
for (int i = 0; i < 20; i++) {
outTuple.f0 = "string #" + i;
collector.collect(outTuple);
}
}
}
private static final class MyMap extends FlatMapFunction<Tuple1<String>, Tuple1<String>> {
@Override
public void flatMap(Tuple1<String> value, Collector<Tuple1<String>> out) throws Exception {
out.collect(value);
}
}
private static final class MySink extends SinkFunction<Tuple1<String>> {
@Override
public void invoke(Tuple1<String> tuple) {
count++;
}
}
@Test
public void test() throws Exception {
StreamExecutionEnvironment context = new StreamExecutionEnvironment();
DataStream<Tuple1<String>> dataStream = context
.addSource(new MySource())
.flatMap(new MyMap()).batch(4)
.flatMap(new MyMap()).batch(2)
.flatMap(new MyMap()).batch(5)
.flatMap(new MyMap()).batch(4)
.addSink(new MySink());
context.execute();
assertEquals(20, count);
}
}
......@@ -70,7 +70,18 @@ public class FlatMapTest {
@Test
public void test() throws Exception {
StreamExecutionEnvironment context = new StreamExecutionEnvironment();
try {
StreamExecutionEnvironment context2 = new StreamExecutionEnvironment(0, 1000);
fail();
} catch (IllegalArgumentException e) {
try {
StreamExecutionEnvironment context2 = new StreamExecutionEnvironment(1, 0);
fail();
} catch (IllegalArgumentException e2) {
}
}
StreamExecutionEnvironment context = new StreamExecutionEnvironment(2, 1000);
DataStream<Tuple1<String>> dataStream0 = context.addSource(new MySource());
DataStream<Tuple1<String>> dataStream1 = context.addDummySource().connectWith(dataStream0)
......@@ -90,7 +101,7 @@ public class FlatMapTest {
FlatMapFunction<Tuple, Tuple> f = (FlatMapFunction<Tuple, Tuple>) in.readObject();
StreamCollector<Tuple> s = new StreamCollector<Tuple>(1, 1, null);
StreamCollector<Tuple> s = new StreamCollector<Tuple>(1, 1000, 1, null);
Tuple t = new Tuple1<String>("asd");
f.flatMap(t, s);
......
......@@ -72,7 +72,7 @@ public class MapTest {
MapFunction<Tuple, Tuple> f = (MapFunction<Tuple, Tuple>) in.readObject();
StreamCollector<Tuple> s = new StreamCollector<Tuple>(1, 1, null);
StreamCollector<Tuple> s = new StreamCollector<Tuple>(1, 1000, 1, null);
Tuple t = new Tuple1<String>("asd");
s.collect(f.map(t));
......
......@@ -25,13 +25,13 @@ public class StreamCollectorTest {
@Test
public void testStreamCollector() {
StreamCollector collector = new StreamCollector(10, 0, null);
StreamCollector collector = new StreamCollector(10, 1000, 0, null);
assertEquals(10, collector.batchSize);
}
@Test
public void testCollect() {
StreamCollector collector = new StreamCollector(2, 0, null);
StreamCollector collector = new StreamCollector(2, 1000, 0, null);
collector.collect(new Tuple1<Integer>(3));
collector.collect(new Tuple1<Integer>(4));
collector.collect(new Tuple1<Integer>(5));
......@@ -39,6 +39,20 @@ public class StreamCollectorTest {
}
@Test
public void testBatchSize() throws InterruptedException {
System.out.println("---------------");
StreamCollector collector = new StreamCollector(3, 100, 0, null);
collector.collect(new Tuple1<Integer>(0));
collector.collect(new Tuple1<Integer>(0));
collector.collect(new Tuple1<Integer>(0));
Thread.sleep(200);
collector.collect(new Tuple1<Integer>(2));
collector.collect(new Tuple1<Integer>(3));
System.out.println("---------------");
}
@Test
public void testClose() {
}
......
......@@ -97,4 +97,17 @@ public class ArrayStreamRecordTest {
}
@Test
public void truncatedSizeTest() {
StreamRecord record = new ArrayStreamRecord(4);
record.setTuple(0, new Tuple1<Integer>(0));
record.setTuple(1, new Tuple1<Integer>(1));
record.setTuple(2, new Tuple1<Integer>(2));
record.setTuple(3, new Tuple1<Integer>(3));
StreamRecord truncatedRecord = new ArrayStreamRecord(record, 2);
assertEquals(2, truncatedRecord.batchSize);
assertEquals(0, truncatedRecord.getTuple(0).getField(0));
assertEquals(1, truncatedRecord.getTuple(1).getField(0));
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册