提交 108eba3a 编写于 作者: G gyfora 提交者: Stephan Ewen

[streaming] connectWith fix

上级 543a27de
...@@ -74,6 +74,11 @@ public class DataStream<T extends Tuple> { ...@@ -74,6 +74,11 @@ public class DataStream<T extends Tuple> {
} }
public DataStream<T> batch(int batchSize) { public DataStream<T> batch(int batchSize) {
if (batchSize < 1) {
throw new IllegalArgumentException("Batch size must be positive.");
}
for (int i = 0; i < batchSizes.size(); i++) { for (int i = 0; i < batchSizes.size(); i++) {
batchSizes.set(i, batchSize); batchSizes.set(i, batchSize);
} }
...@@ -85,6 +90,7 @@ public class DataStream<T extends Tuple> { ...@@ -85,6 +90,7 @@ public class DataStream<T extends Tuple> {
connectIDs.addAll(stream.connectIDs); connectIDs.addAll(stream.connectIDs);
ctypes.addAll(stream.ctypes); ctypes.addAll(stream.ctypes);
cparams.addAll(stream.cparams); cparams.addAll(stream.cparams);
batchSizes.addAll(stream.batchSizes);
return this; return this;
} }
......
...@@ -64,6 +64,8 @@ public class JobGraphBuilder { ...@@ -64,6 +64,8 @@ public class JobGraphBuilder {
protected String maxParallelismVertexName; protected String maxParallelismVertexName;
protected int maxParallelism; protected int maxParallelism;
protected FaultToleranceType faultToleranceType; protected FaultToleranceType faultToleranceType;
private int batchSize;
private long batchTimeout;
/** /**
* Creates a new JobGraph with the given name * Creates a new JobGraph with the given name
...@@ -97,8 +99,11 @@ public class JobGraphBuilder { ...@@ -97,8 +99,11 @@ public class JobGraphBuilder {
this(jobGraphName, FaultToleranceType.NONE); this(jobGraphName, FaultToleranceType.NONE);
} }
public JobGraphBuilder(String jobGraphName, FaultToleranceType faultToleranceType, int batchSize) { public JobGraphBuilder(String jobGraphName, FaultToleranceType faultToleranceType,
int defaultBatchSize, long defaultBatchTimeoutMillis) {
this(jobGraphName, faultToleranceType); this(jobGraphName, faultToleranceType);
this.batchSize = defaultBatchSize;
this.batchTimeout = defaultBatchTimeoutMillis;
} }
/** /**
...@@ -244,7 +249,6 @@ public class JobGraphBuilder { ...@@ -244,7 +249,6 @@ public class JobGraphBuilder {
* @param component * @param component
* AbstractJobVertex associated with the component * AbstractJobVertex associated with the component
*/ */
private Configuration setComponent(String componentName, private Configuration setComponent(String componentName,
final Class<? extends StreamComponent> InvokableClass, int parallelism, final Class<? extends StreamComponent> InvokableClass, int parallelism,
int subtasksPerInstance, AbstractJobVertex component) { int subtasksPerInstance, AbstractJobVertex component) {
...@@ -259,6 +263,8 @@ public class JobGraphBuilder { ...@@ -259,6 +263,8 @@ public class JobGraphBuilder {
Configuration config = new TaskConfig(component.getConfiguration()).getConfiguration(); Configuration config = new TaskConfig(component.getConfiguration()).getConfiguration();
config.setClass("userfunction", InvokableClass); config.setClass("userfunction", InvokableClass);
config.setString("componentName", componentName); config.setString("componentName", componentName);
config.setInteger("batchSize", batchSize);
config.setLong("batchTimeout", batchTimeout);
// config.setBytes("operator", getSerializedFunction()); // config.setBytes("operator", getSerializedFunction());
config.setInteger("faultToleranceType", faultToleranceType.id); config.setInteger("faultToleranceType", faultToleranceType.id);
...@@ -268,12 +274,6 @@ public class JobGraphBuilder { ...@@ -268,12 +274,6 @@ public class JobGraphBuilder {
return config; return config;
} }
public void setBatchSize(String componentName, int batchSize) {
AbstractJobVertex component = components.get(componentName);
Configuration config = component.getConfiguration();
config.setInteger("batchSize", batchSize);
}
private Configuration setComponent(String componentName, private Configuration setComponent(String componentName,
UserSourceInvokable<? extends Tuple> InvokableObject, int parallelism, UserSourceInvokable<? extends Tuple> InvokableObject, int parallelism,
int subtasksPerInstance, AbstractJobVertex component) { int subtasksPerInstance, AbstractJobVertex component) {
...@@ -304,6 +304,11 @@ public class JobGraphBuilder { ...@@ -304,6 +304,11 @@ public class JobGraphBuilder {
return config; return config;
} }
public void setBatchSize(String componentName, int batchSize) {
Configuration config = components.get(componentName).getConfiguration();
config.setInteger("batchSize", batchSize);
}
/** /**
* Adds serialized invokable object to the JobVertex configuration * Adds serialized invokable object to the JobVertex configuration
* *
......
...@@ -28,37 +28,54 @@ public class StreamCollector<T extends Tuple> implements Collector<T> { ...@@ -28,37 +28,54 @@ public class StreamCollector<T extends Tuple> implements Collector<T> {
protected StreamRecord streamRecord; protected StreamRecord streamRecord;
protected int batchSize; protected int batchSize;
protected long batchTimeout;
protected int counter = 0; protected int counter = 0;
protected int channelID; protected int channelID;
private long timeOfLastRecordEmitted = System.currentTimeMillis();;
private List<RecordWriter<StreamRecord>> outputs; private List<RecordWriter<StreamRecord>> outputs;
public StreamCollector(int batchSize, int channelID, public StreamCollector(int batchSize, long batchTimeout, int channelID,
SerializationDelegate<Tuple> serializationDelegate, SerializationDelegate<Tuple> serializationDelegate,
List<RecordWriter<StreamRecord>> outputs) { List<RecordWriter<StreamRecord>> outputs) {
this.batchSize = batchSize; this.batchSize = batchSize;
this.batchTimeout = batchTimeout;
this.streamRecord = new ArrayStreamRecord(batchSize); this.streamRecord = new ArrayStreamRecord(batchSize);
this.streamRecord.setSeralizationDelegate(serializationDelegate); this.streamRecord.setSeralizationDelegate(serializationDelegate);
this.channelID = channelID; this.channelID = channelID;
this.outputs = outputs; this.outputs = outputs;
} }
public StreamCollector(int batchSize, int channelID, public StreamCollector(int batchSize, long batchTimeout, int channelID,
SerializationDelegate<Tuple> serializationDelegate) { SerializationDelegate<Tuple> serializationDelegate) {
this(batchSize, channelID, serializationDelegate, null); this(batchSize, batchTimeout, channelID, serializationDelegate, null);
} }
// TODO reconsider emitting mechanism at timeout (find a place to timeout)
@Override @Override
public void collect(T tuple) { public void collect(T tuple) {
streamRecord.setTuple(counter, StreamRecord.copyTuple(tuple)); streamRecord.setTuple(counter, StreamRecord.copyTuple(tuple));
counter++; counter++;
if (counter >= batchSize) { if (counter >= batchSize) {
counter = 0;
streamRecord.setId(channelID);
emit(streamRecord); emit(streamRecord);
// timeOfLastRecordEmitted = System.currentTimeMillis();
} else {
// timeout();
}
}
public void timeout() {
if (timeOfLastRecordEmitted + batchTimeout < System.currentTimeMillis()) {
StreamRecord truncatedRecord = new ArrayStreamRecord(streamRecord, counter);
emit(truncatedRecord);
timeOfLastRecordEmitted = System.currentTimeMillis();
} }
} }
private void emit(StreamRecord streamRecord) { private void emit(StreamRecord streamRecord) {
counter = 0;
streamRecord.setId(channelID);
if (outputs == null) { if (outputs == null) {
System.out.println(streamRecord); System.out.println(streamRecord);
} else { } else {
......
...@@ -33,9 +33,19 @@ import eu.stratosphere.util.Collector; ...@@ -33,9 +33,19 @@ import eu.stratosphere.util.Collector;
public class StreamExecutionEnvironment { public class StreamExecutionEnvironment {
JobGraphBuilder jobGraphBuilder; JobGraphBuilder jobGraphBuilder;
public StreamExecutionEnvironment() { public StreamExecutionEnvironment(int defaultBatchSize, long defaultBatchTimeoutMillis) {
jobGraphBuilder = new JobGraphBuilder("jobGraph", FaultToleranceType.NONE); if (defaultBatchSize < 1) {
throw new IllegalArgumentException("Batch size must be positive.");
}
if (defaultBatchTimeoutMillis < 1) {
throw new IllegalArgumentException("Batch timeout must be positive.");
}
jobGraphBuilder = new JobGraphBuilder("jobGraph", FaultToleranceType.NONE,
defaultBatchSize, defaultBatchTimeoutMillis);
}
public StreamExecutionEnvironment() {
this(1, 1000);
} }
private static class DummySource extends UserSourceInvokable<Tuple1<String>> { private static class DummySource extends UserSourceInvokable<Tuple1<String>> {
...@@ -54,6 +64,7 @@ public class StreamExecutionEnvironment { ...@@ -54,6 +64,7 @@ public class StreamExecutionEnvironment {
} }
public <T extends Tuple> void setBatchSize(DataStream<T> inputStream) { public <T extends Tuple> void setBatchSize(DataStream<T> inputStream) {
for (int i = 0; i < inputStream.connectIDs.size(); i++) { for (int i = 0; i < inputStream.connectIDs.size(); i++) {
jobGraphBuilder.setBatchSize(inputStream.connectIDs.get(i), jobGraphBuilder.setBatchSize(inputStream.connectIDs.get(i),
inputStream.batchSizes.get(i)); inputStream.batchSizes.get(i));
......
...@@ -107,8 +107,11 @@ public final class StreamComponentHelper<T extends AbstractInvokable> { ...@@ -107,8 +107,11 @@ public final class StreamComponentHelper<T extends AbstractInvokable> {
public StreamCollector<Tuple> setCollector(Configuration taskConfiguration, int id, public StreamCollector<Tuple> setCollector(Configuration taskConfiguration, int id,
List<RecordWriter<StreamRecord>> outputs) { List<RecordWriter<StreamRecord>> outputs) {
int batchSize = taskConfiguration.getInteger("batchSize", 1); int batchSize = taskConfiguration.getInteger("batchSize", 1);
collector = new StreamCollector<Tuple>(batchSize, id, outSerializationDelegate, outputs); long batchTimeout = taskConfiguration.getLong("batchTimeout", 1000);
collector = new StreamCollector<Tuple>(batchSize, batchTimeout, id,
outSerializationDelegate, outputs);
return collector; return collector;
} }
......
...@@ -45,14 +45,19 @@ public class ArrayStreamRecord extends StreamRecord { ...@@ -45,14 +45,19 @@ public class ArrayStreamRecord extends StreamRecord {
} }
public ArrayStreamRecord(StreamRecord record) { public ArrayStreamRecord(StreamRecord record) {
tupleBatch = new Tuple[record.getBatchSize()]; this(record, record.getBatchSize());
}
public ArrayStreamRecord(StreamRecord record, int truncatedSize) {
tupleBatch = new Tuple[truncatedSize];
this.uid = new UID(Arrays.copyOf(record.getId().getId(), 20)); this.uid = new UID(Arrays.copyOf(record.getId().getId(), 20));
for (int i = 0; i < record.getBatchSize(); ++i) { for (int i = 0; i < truncatedSize; ++i) {
this.tupleBatch[i] = copyTuple(record.getTuple(i)); this.tupleBatch[i] = copyTuple(record.getTuple(i));
} }
this.batchSize = tupleBatch.length; this.batchSize = tupleBatch.length;
} }
/** /**
* Creates a new batch of records containing the given Tuple array as * Creates a new batch of records containing the given Tuple array as
* elements * elements
......
...@@ -15,13 +15,10 @@ ...@@ -15,13 +15,10 @@
package eu.stratosphere.streaming.api; package eu.stratosphere.streaming.api;
import static org.junit.Assert.fail;
import java.util.Iterator; import java.util.Iterator;
import org.junit.Test; import org.junit.Test;
import eu.stratosphere.api.java.functions.FlatMapFunction;
import eu.stratosphere.api.java.functions.GroupReduceFunction; import eu.stratosphere.api.java.functions.GroupReduceFunction;
import eu.stratosphere.api.java.tuple.Tuple1; import eu.stratosphere.api.java.tuple.Tuple1;
import eu.stratosphere.util.Collector; import eu.stratosphere.util.Collector;
...@@ -72,8 +69,8 @@ public class BatchReduceTest { ...@@ -72,8 +69,8 @@ public class BatchReduceTest {
@Test @Test
public void test() throws Exception { public void test() throws Exception {
StreamExecutionEnvironment context = new StreamExecutionEnvironment(); StreamExecutionEnvironment context = new StreamExecutionEnvironment(4, 1000);
DataStream<Tuple1<Double>> dataStream0 = context.addSource(new MySource()).batch(4) DataStream<Tuple1<Double>> dataStream0 = context.addSource(new MySource())
.batchReduce(new MyBatchReduce()).addSink(new MySink()); .batchReduce(new MyBatchReduce()).addSink(new MySink());
context.execute(); context.execute();
......
package eu.stratosphere.streaming.api;
import static org.junit.Assert.assertEquals;
import org.junit.Test;
import eu.stratosphere.api.java.functions.FlatMapFunction;
import eu.stratosphere.api.java.tuple.Tuple1;
import eu.stratosphere.util.Collector;
/**
 * Runs a small streaming job whose stages use different batch sizes and
 * verifies that every tuple produced by the source arrives at the sink.
 */
public class BatchTest {

	/** Number of tuples received by {@link MySink} so far. */
	private static int count = 0;

	/** Source that emits the twenty tuples "string #0" through "string #19". */
	private static final class MySource extends SourceFunction<Tuple1<String>> {

		// A single tuple instance is refilled and handed to the collector on each iteration.
		private Tuple1<String> reusableTuple = new Tuple1<String>();

		@Override
		public void invoke(Collector<Tuple1<String>> collector) throws Exception {
			int index = 0;
			while (index < 20) {
				reusableTuple.f0 = "string #" + index;
				collector.collect(reusableTuple);
				index++;
			}
		}
	}

	/** Pass-through operator: forwards each incoming tuple unchanged. */
	private static final class MyMap extends FlatMapFunction<Tuple1<String>, Tuple1<String>> {

		@Override
		public void flatMap(Tuple1<String> input, Collector<Tuple1<String>> collector)
				throws Exception {
			collector.collect(input);
		}
	}

	/** Sink that only counts the tuples it is invoked with. */
	private static final class MySink extends SinkFunction<Tuple1<String>> {

		@Override
		public void invoke(Tuple1<String> tuple) {
			count += 1;
		}
	}

	/**
	 * Chains four pass-through stages with batch sizes 4, 2, 5 and 4 between
	 * the source and the sink, executes the job, and expects all twenty
	 * source tuples to have been counted by the sink.
	 */
	@Test
	public void test() throws Exception {
		StreamExecutionEnvironment environment = new StreamExecutionEnvironment();

		DataStream<Tuple1<String>> pipeline = environment
				.addSource(new MySource())
				.flatMap(new MyMap()).batch(4)
				.flatMap(new MyMap()).batch(2)
				.flatMap(new MyMap()).batch(5)
				.flatMap(new MyMap()).batch(4)
				.addSink(new MySink());

		environment.execute();

		assertEquals(20, count);
	}
}
...@@ -70,7 +70,18 @@ public class FlatMapTest { ...@@ -70,7 +70,18 @@ public class FlatMapTest {
@Test @Test
public void test() throws Exception { public void test() throws Exception {
StreamExecutionEnvironment context = new StreamExecutionEnvironment(); try {
StreamExecutionEnvironment context2 = new StreamExecutionEnvironment(0, 1000);
fail();
} catch (IllegalArgumentException e) {
try {
StreamExecutionEnvironment context2 = new StreamExecutionEnvironment(1, 0);
fail();
} catch (IllegalArgumentException e2) {
}
}
StreamExecutionEnvironment context = new StreamExecutionEnvironment(2, 1000);
DataStream<Tuple1<String>> dataStream0 = context.addSource(new MySource()); DataStream<Tuple1<String>> dataStream0 = context.addSource(new MySource());
DataStream<Tuple1<String>> dataStream1 = context.addDummySource().connectWith(dataStream0) DataStream<Tuple1<String>> dataStream1 = context.addDummySource().connectWith(dataStream0)
...@@ -90,7 +101,7 @@ public class FlatMapTest { ...@@ -90,7 +101,7 @@ public class FlatMapTest {
FlatMapFunction<Tuple, Tuple> f = (FlatMapFunction<Tuple, Tuple>) in.readObject(); FlatMapFunction<Tuple, Tuple> f = (FlatMapFunction<Tuple, Tuple>) in.readObject();
StreamCollector<Tuple> s = new StreamCollector<Tuple>(1, 1, null); StreamCollector<Tuple> s = new StreamCollector<Tuple>(1, 1000, 1, null);
Tuple t = new Tuple1<String>("asd"); Tuple t = new Tuple1<String>("asd");
f.flatMap(t, s); f.flatMap(t, s);
......
...@@ -72,7 +72,7 @@ public class MapTest { ...@@ -72,7 +72,7 @@ public class MapTest {
MapFunction<Tuple, Tuple> f = (MapFunction<Tuple, Tuple>) in.readObject(); MapFunction<Tuple, Tuple> f = (MapFunction<Tuple, Tuple>) in.readObject();
StreamCollector<Tuple> s = new StreamCollector<Tuple>(1, 1, null); StreamCollector<Tuple> s = new StreamCollector<Tuple>(1, 1000, 1, null);
Tuple t = new Tuple1<String>("asd"); Tuple t = new Tuple1<String>("asd");
s.collect(f.map(t)); s.collect(f.map(t));
......
...@@ -25,13 +25,13 @@ public class StreamCollectorTest { ...@@ -25,13 +25,13 @@ public class StreamCollectorTest {
@Test @Test
public void testStreamCollector() { public void testStreamCollector() {
StreamCollector collector = new StreamCollector(10, 0, null); StreamCollector collector = new StreamCollector(10, 1000, 0, null);
assertEquals(10, collector.batchSize); assertEquals(10, collector.batchSize);
} }
@Test @Test
public void testCollect() { public void testCollect() {
StreamCollector collector = new StreamCollector(2, 0, null); StreamCollector collector = new StreamCollector(2, 1000, 0, null);
collector.collect(new Tuple1<Integer>(3)); collector.collect(new Tuple1<Integer>(3));
collector.collect(new Tuple1<Integer>(4)); collector.collect(new Tuple1<Integer>(4));
collector.collect(new Tuple1<Integer>(5)); collector.collect(new Tuple1<Integer>(5));
...@@ -39,6 +39,20 @@ public class StreamCollectorTest { ...@@ -39,6 +39,20 @@ public class StreamCollectorTest {
} }
/**
 * Drives a collector configured with batch size 3 and a 100 ms batch
 * timeout: three tuples are collected, the thread pauses for 200 ms, then
 * two more tuples are collected. The test makes no assertions; it prints a
 * dashed separator line before and after the run so any collector output
 * can be inspected manually.
 */
@Test
public void testBatchSize() throws InterruptedException {
	final String separator = "---------------";
	System.out.println(separator);

	StreamCollector collector = new StreamCollector(3, 100, 0, null);

	// First three tuples fill exactly one batch of size 3.
	for (int i = 0; i < 3; i++) {
		collector.collect(new Tuple1<Integer>(0));
	}

	// Wait longer than the configured 100 ms timeout before continuing.
	Thread.sleep(200);

	// Two further tuples, carrying the values 2 and 3.
	for (int value = 2; value <= 3; value++) {
		collector.collect(new Tuple1<Integer>(value));
	}

	System.out.println(separator);
}
@Test @Test
public void testClose() { public void testClose() {
} }
......
...@@ -97,4 +97,17 @@ public class ArrayStreamRecordTest { ...@@ -97,4 +97,17 @@ public class ArrayStreamRecordTest {
} }
/**
 * Builds a four-element record holding the values 0..3, copies it with a
 * truncated size of 2, and checks that the copy reports a batch size of 2
 * and contains the first two values in order.
 */
@Test
public void truncatedSizeTest() {
	final int fullSize = 4;
	final int truncatedSize = 2;

	StreamRecord original = new ArrayStreamRecord(fullSize);
	for (int i = 0; i < fullSize; i++) {
		original.setTuple(i, new Tuple1<Integer>(i));
	}

	StreamRecord truncated = new ArrayStreamRecord(original, truncatedSize);

	assertEquals(2, truncated.batchSize);
	for (int i = 0; i < truncatedSize; i++) {
		assertEquals(i, truncated.getTuple(i).getField(0));
	}
}
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册