提交 4b03b9bb 编写于 作者: G Gyula Fora 提交者: Stephan Ewen

[streaming] updated fault tolerance buffer and jobgraphbuilder to properly...

[streaming] updated fault tolerance buffer and jobgraphbuilder to properly handle broadcast partitioning
上级 f9cbd54d
......@@ -33,7 +33,7 @@ import eu.stratosphere.nephele.io.RecordWriter;
*/
public class FaultToleranceBuffer {
private long TIMEOUT = 1000;
private long TIMEOUT = 10000;
private Long timeOfLastUpdate;
private Map<String, StreamRecord> recordBuffer;
private Map<String, Integer> ackCounter;
......@@ -50,17 +50,20 @@ public class FaultToleranceBuffer {
* channel ID
*
* @param outputs
* List of outputs
* List of outputs
* @param channelID
* ID of the task object that uses this buffer
* ID of the task object that uses this buffer
* @param numberOfChannels
* Number of output channels for the component
*/
public FaultToleranceBuffer(List<RecordWriter<StreamRecord>> outputs,
String channelID) {
String channelID, int numberOfChannels) {
this.timeOfLastUpdate = System.currentTimeMillis();
this.outputs = outputs;
this.recordBuffer = new HashMap<String, StreamRecord>();
this.ackCounter = new HashMap<String, Integer>();
this.numberOfOutputs = outputs.size();
this.numberOfOutputs = numberOfChannels;
this.channelID = channelID;
this.recordsByTime = new TreeMap<Long, Set<String>>();
this.recordTimestamps = new HashMap<String, Long>();
......@@ -72,26 +75,24 @@ public class FaultToleranceBuffer {
*
*/
public void addRecord(StreamRecord streamRecord) {
String id=streamRecord.getId();
String id = streamRecord.getId();
recordBuffer.put(id, streamRecord.copy());
ackCounter.put(id, numberOfOutputs);
addTimestamp(id);
}
/**
* Checks for records that have timed out since the last check and fails
* them.
* Checks for records that have timed out since the last check and fails them.
*
* @param currentTime
* Time when the check should be made, usually current system
* time.
* Time when the check should be made, usually current system time.
* @return Returns the list of the records that have timed out.
*/
List<String> timeoutRecords(Long currentTime) {
if (timeOfLastUpdate + TIMEOUT < currentTime) {
List<String> timedOutRecords = new LinkedList<String>();
Map<Long, Set<String>> timedOut = recordsByTime.subMap(0L,
currentTime - TIMEOUT);
Map<Long, Set<String>> timedOut = recordsByTime.subMap(0L, currentTime
- TIMEOUT);
for (Set<String> recordSet : timedOut.values()) {
if (!recordSet.isEmpty()) {
......@@ -114,20 +115,20 @@ public class FaultToleranceBuffer {
/**
* Stores time stamp for a record by recordID and also adds the record to a
* map which maps a time stamp to the IDs of records that were emitted at
* that time.
* map which maps a time stamp to the IDs of records that were emitted at that
* time.
* <p>
* Later used for timeouts.
*
* @param recordID
* ID of the record
* ID of the record
*/
public void addTimestamp(String recordID) {
Long currentTime = System.currentTimeMillis();
recordTimestamps.put(recordID, currentTime);
Set<String> recordSet = recordsByTime.get(currentTime);
if (recordSet != null) {
recordSet.add(recordID);
} else {
......@@ -141,7 +142,7 @@ public class FaultToleranceBuffer {
* Returns a StreamRecord after removing it from the buffer
*
* @param recordID
* The ID of the record that will be popped
* The ID of the record that will be popped
*/
public StreamRecord popRecord(String recordID) {
System.out.println("Pop ID: " + recordID);
......@@ -149,15 +150,15 @@ public class FaultToleranceBuffer {
}
/**
* Removes a StreamRecord by ID from the fault tolerance buffer, further
* acks will have no effects for this record.
* Removes a StreamRecord by ID from the fault tolerance buffer, further acks
* will have no effects for this record.
*
* @param recordID
* The ID of the record that will be removed
* The ID of the record that will be removed
*
*/
StreamRecord removeRecord(String recordID) {
ackCounter.remove(recordID);
try {
recordsByTime.get(recordTimestamps.remove(recordID)).remove(recordID);
......@@ -175,7 +176,7 @@ public class FaultToleranceBuffer {
* acknowledgments, removes it from the buffer
*
* @param recordID
* ID of the record that has been acknowledged
* ID of the record that has been acknowledged
*/
// TODO: find a place to call timeoutRecords
public void ackRecord(String recordID) {
......@@ -195,7 +196,7 @@ public class FaultToleranceBuffer {
* stores it with a new ID.
*
* @param recordID
* ID of the record that has been failed
* ID of the record that has been failed
*/
public void failRecord(String recordID) {
// Create new id to avoid double counting acks
......@@ -209,7 +210,7 @@ public class FaultToleranceBuffer {
* Emit give record to all output channels
*
* @param record
* Record to be re-emitted
* Record to be re-emitted
*/
public void reEmit(StreamRecord record) {
for (RecordWriter<StreamRecord> output : outputs) {
......
......@@ -48,6 +48,8 @@ public class JobGraphBuilder {
private final JobGraph jobGraph;
private Map<String, AbstractJobVertex> components;
private Map<String, Integer> numberOfInstances;
private Map<String, Integer> numberOfOutputChannels;
/**
* Creates a new JobGraph with the given name
......@@ -59,6 +61,9 @@ public class JobGraphBuilder {
jobGraph = new JobGraph(jobGraphName);
components = new HashMap<String, AbstractJobVertex>();
numberOfInstances = new HashMap<String, Integer>();
numberOfOutputChannels = new HashMap<String, Integer>();
}
/**
......@@ -79,6 +84,7 @@ public class JobGraphBuilder {
.getConfiguration();
config.setClass("userfunction", InvokableClass);
components.put(sourceName, source);
numberOfInstances.put(sourceName, 1);
}
/**
......@@ -101,6 +107,7 @@ public class JobGraphBuilder {
.getConfiguration();
config.setClass("userfunction", InvokableClass);
components.put(taskName, task);
numberOfInstances.put(taskName, parallelism);
}
/**
......@@ -120,6 +127,7 @@ public class JobGraphBuilder {
.getConfiguration();
config.setClass("userfunction", InvokableClass);
components.put(sinkName, sink);
numberOfInstances.put(sinkName, 1);
}
/**
......@@ -174,6 +182,17 @@ public class JobGraphBuilder {
connect(upStreamComponentName, downStreamComponentName,
BroadcastPartitioner.class, ChannelType.INMEMORY);
if (numberOfOutputChannels.containsKey(upStreamComponentName)) {
numberOfOutputChannels.put(
upStreamComponentName,
numberOfOutputChannels.get(upStreamComponentName)
+ numberOfInstances.get(downStreamComponentName));
} else {
numberOfOutputChannels.put(upStreamComponentName,
numberOfInstances.get(downStreamComponentName));
}
}
/**
......@@ -219,6 +238,8 @@ public class JobGraphBuilder {
"partitionerIntParam_"
+ upStreamComponent.getNumberOfForwardConnections(), keyPosition);
addOutputChannels(upStreamComponentName);
} catch (JobGraphDefinitionException e) {
e.printStackTrace();
}
......@@ -240,6 +261,9 @@ public class JobGraphBuilder {
connect(upStreamComponentName, downStreamComponentName,
GlobalPartitioner.class, ChannelType.INMEMORY);
addOutputChannels(upStreamComponentName);
}
/**
......@@ -258,6 +282,17 @@ public class JobGraphBuilder {
connect(upStreamComponentName, downStreamComponentName,
ShufflePartitioner.class, ChannelType.INMEMORY);
addOutputChannels(upStreamComponentName);
}
private void addOutputChannels(String upStreamComponentName) {
if (numberOfOutputChannels.containsKey(upStreamComponentName)) {
numberOfOutputChannels.put(upStreamComponentName,
numberOfOutputChannels.get(upStreamComponentName) + 1);
} else {
numberOfOutputChannels.put(upStreamComponentName, 1);
}
}
private void setNumberOfJobInputs() {
......@@ -272,6 +307,13 @@ public class JobGraphBuilder {
component.getConfiguration().setInteger("numberOfOutputs",
component.getNumberOfForwardConnections());
}
for (String component : numberOfOutputChannels.keySet()) {
components
.get(component)
.getConfiguration()
.setInteger("numberOfOutputChannels",
numberOfOutputChannels.get(component));
}
}
/**
......
......@@ -67,8 +67,8 @@ public class StreamSource extends AbstractInputTask<RandIS> {
} catch (StreamComponentException e) {
e.printStackTrace();
}
recordBuffer = new FaultToleranceBuffer(outputs, sourceInstanceID);
recordBuffer = new FaultToleranceBuffer(outputs, sourceInstanceID, taskConfiguration.getInteger("numberOfOutputChannels", -1));
userFunction = (UserSourceInvokable) streamSourceHelper.getUserFunction(
taskConfiguration, outputs, sourceInstanceID, recordBuffer);
streamSourceHelper.setAckListener(recordBuffer, sourceInstanceID, outputs);
......
......@@ -64,7 +64,7 @@ public class StreamTask extends AbstractTask {
e.printStackTrace();
}
recordBuffer = new FaultToleranceBuffer(outputs, taskInstanceID);
recordBuffer = new FaultToleranceBuffer(outputs, taskInstanceID,taskConfiguration.getInteger("numberOfOutputChannels", -1));
userFunction = (UserTaskInvokable) streamTaskHelper.getUserFunction(
taskConfiguration, outputs, taskInstanceID, recordBuffer);
streamTaskHelper.setAckListener(recordBuffer, taskInstanceID, outputs);
......
......@@ -37,8 +37,7 @@ public class FaultToleranceBufferTest {
@Before
public void setFaultTolerancyBuffer() {
outputs = new LinkedList<RecordWriter<StreamRecord>>();
faultTolerancyBuffer = new FaultToleranceBuffer(outputs, "1");
faultTolerancyBuffer.setNumberOfOutputs(3);
faultTolerancyBuffer = new FaultToleranceBuffer(outputs, "1", 3);
}
@Test
......@@ -53,8 +52,10 @@ public class FaultToleranceBufferTest {
StreamRecord record = (new StreamRecord(1)).setId("1");
record.addRecord(new StringValue("V1"));
faultTolerancyBuffer.addRecord(record);
assertEquals((Integer) 3, faultTolerancyBuffer.getAckCounter().get(record.getId()));
assertArrayEquals(record.getRecord(0),faultTolerancyBuffer.getRecordBuffer().get(record.getId()).getRecord(0));
assertEquals((Integer) 3,
faultTolerancyBuffer.getAckCounter().get(record.getId()));
assertArrayEquals(record.getRecord(0), faultTolerancyBuffer
.getRecordBuffer().get(record.getId()).getRecord(0));
}
@Test
......@@ -71,30 +72,31 @@ public class FaultToleranceBufferTest {
assertArrayEquals(records,
faultTolerancyBuffer.getRecordsByTime().get(recordTimeStamp).toArray());
try {
Thread.sleep(2);
} catch (InterruptedException e) {
e.printStackTrace();
}
faultTolerancyBuffer.addTimestamp("1-1338");
faultTolerancyBuffer.addTimestamp("1-1339");
long recordTimeStamp1 = faultTolerancyBuffer.getRecordTimestamps().get(
"1-1338");
long recordTimeStamp2 = faultTolerancyBuffer.getRecordTimestamps().get(
"1-1339");
records = new String[] { "1-1338","1-1339"};
if(recordTimeStamp1==recordTimeStamp2){
assertTrue(faultTolerancyBuffer.getRecordsByTime().get(recordTimeStamp1).contains("1-1338"));
assertTrue(faultTolerancyBuffer.getRecordsByTime().get(recordTimeStamp1).contains("1-1339"));
assertTrue(faultTolerancyBuffer.getRecordsByTime().get(recordTimeStamp1).size()==2);
records = new String[] { "1-1338", "1-1339" };
if (recordTimeStamp1 == recordTimeStamp2) {
assertTrue(faultTolerancyBuffer.getRecordsByTime().get(recordTimeStamp1)
.contains("1-1338"));
assertTrue(faultTolerancyBuffer.getRecordsByTime().get(recordTimeStamp1)
.contains("1-1339"));
assertTrue(faultTolerancyBuffer.getRecordsByTime().get(recordTimeStamp1)
.size() == 2);
}
}
@Test
......@@ -102,35 +104,46 @@ public class FaultToleranceBufferTest {
StreamRecord record1 = (new StreamRecord(1)).setId("1");
record1.addRecord(new StringValue("V1"));
faultTolerancyBuffer.addRecord(record1);
assertArrayEquals(record1.getRecord(0), faultTolerancyBuffer.popRecord(record1.getId()).getRecord(0));
assertArrayEquals(record1.getRecord(0),
faultTolerancyBuffer.popRecord(record1.getId()).getRecord(0));
System.out.println("---------");
}
@Test
public void testRemoveRecord() {
StreamRecord record1 = (new StreamRecord(1)).setId("1");
record1.addRecord(new StringValue("V1"));
StreamRecord record2 = (new StreamRecord(1)).setId("1");
record2.addRecord(new StringValue("V2"));
faultTolerancyBuffer.addRecord(record1);
faultTolerancyBuffer.addRecord(record2);
Long record1TS=faultTolerancyBuffer.getRecordTimestamps().get(record1.getId());
Long record2TS=faultTolerancyBuffer.getRecordTimestamps().get(record2.getId());
Long record1TS = faultTolerancyBuffer.getRecordTimestamps().get(
record1.getId());
Long record2TS = faultTolerancyBuffer.getRecordTimestamps().get(
record2.getId());
faultTolerancyBuffer.removeRecord(record1.getId());
assertTrue(faultTolerancyBuffer.getRecordBuffer().containsKey(record2.getId()));
assertTrue(faultTolerancyBuffer.getAckCounter().containsKey(record2.getId()));
assertTrue(faultTolerancyBuffer.getRecordTimestamps().containsKey(record2.getId()));
assertTrue(faultTolerancyBuffer.getRecordsByTime().get(record2TS).contains(record2.getId()));
assertFalse(faultTolerancyBuffer.getRecordBuffer().containsKey(record1.getId()));
assertFalse(faultTolerancyBuffer.getAckCounter().containsKey(record1.getId()));
assertFalse(faultTolerancyBuffer.getRecordTimestamps().containsKey(record1.getId()));
assertFalse(faultTolerancyBuffer.getRecordsByTime().get(record1TS).contains(record1.getId()));
assertTrue(faultTolerancyBuffer.getRecordBuffer().containsKey(
record2.getId()));
assertTrue(faultTolerancyBuffer.getAckCounter()
.containsKey(record2.getId()));
assertTrue(faultTolerancyBuffer.getRecordTimestamps().containsKey(
record2.getId()));
assertTrue(faultTolerancyBuffer.getRecordsByTime().get(record2TS)
.contains(record2.getId()));
assertFalse(faultTolerancyBuffer.getRecordBuffer().containsKey(
record1.getId()));
assertFalse(faultTolerancyBuffer.getAckCounter().containsKey(
record1.getId()));
assertFalse(faultTolerancyBuffer.getRecordTimestamps().containsKey(
record1.getId()));
assertFalse(faultTolerancyBuffer.getRecordsByTime().get(record1TS)
.contains(record1.getId()));
}
@Test
......@@ -138,22 +151,32 @@ public class FaultToleranceBufferTest {
StreamRecord record1 = (new StreamRecord(1)).setId("1");
record1.addRecord(new StringValue("V1"));
faultTolerancyBuffer.addRecord(record1);
Long record1TS=faultTolerancyBuffer.getRecordTimestamps().get(record1.getId());
Long record1TS = faultTolerancyBuffer.getRecordTimestamps().get(
record1.getId());
faultTolerancyBuffer.ackRecord(record1.getId());
faultTolerancyBuffer.ackRecord(record1.getId());
assertEquals((Integer) 1, faultTolerancyBuffer.getAckCounter().get(record1.getId()));
assertTrue(faultTolerancyBuffer.getRecordBuffer().containsKey(record1.getId()));
assertTrue(faultTolerancyBuffer.getAckCounter().containsKey(record1.getId()));
assertTrue(faultTolerancyBuffer.getRecordTimestamps().containsKey(record1.getId()));
assertTrue(faultTolerancyBuffer.getRecordsByTime().get(record1TS).contains(record1.getId()));
assertEquals((Integer) 1,
faultTolerancyBuffer.getAckCounter().get(record1.getId()));
assertTrue(faultTolerancyBuffer.getRecordBuffer().containsKey(
record1.getId()));
assertTrue(faultTolerancyBuffer.getAckCounter()
.containsKey(record1.getId()));
assertTrue(faultTolerancyBuffer.getRecordTimestamps().containsKey(
record1.getId()));
assertTrue(faultTolerancyBuffer.getRecordsByTime().get(record1TS)
.contains(record1.getId()));
faultTolerancyBuffer.ackRecord(record1.getId());
assertFalse(faultTolerancyBuffer.getRecordBuffer().containsKey(record1.getId()));
assertFalse(faultTolerancyBuffer.getAckCounter().containsKey(record1.getId()));
assertFalse(faultTolerancyBuffer.getRecordTimestamps().containsKey(record1.getId()));
assertFalse(faultTolerancyBuffer.getRecordsByTime().get(record1TS).contains(record1.getId()));
assertFalse(faultTolerancyBuffer.getRecordBuffer().containsKey(
record1.getId()));
assertFalse(faultTolerancyBuffer.getAckCounter().containsKey(
record1.getId()));
assertFalse(faultTolerancyBuffer.getRecordTimestamps().containsKey(
record1.getId()));
assertFalse(faultTolerancyBuffer.getRecordsByTime().get(record1TS)
.contains(record1.getId()));
faultTolerancyBuffer.ackRecord(record1.getId());
}
......@@ -162,109 +185,125 @@ public class FaultToleranceBufferTest {
StreamRecord record1 = (new StreamRecord(1)).setId("1");
record1.addRecord(new StringValue("V1"));
faultTolerancyBuffer.addRecord(record1);
Long record1TS=faultTolerancyBuffer.getRecordTimestamps().get(record1.getId());
assertTrue(faultTolerancyBuffer.getRecordBuffer().containsKey(record1.getId()));
assertTrue(faultTolerancyBuffer.getAckCounter().containsKey(record1.getId()));
assertTrue(faultTolerancyBuffer.getRecordTimestamps().containsKey(record1.getId()));
assertTrue(faultTolerancyBuffer.getRecordsByTime().get(record1TS).contains(record1.getId()));
Long record1TS = faultTolerancyBuffer.getRecordTimestamps().get(
record1.getId());
assertTrue(faultTolerancyBuffer.getRecordBuffer().containsKey(
record1.getId()));
assertTrue(faultTolerancyBuffer.getAckCounter()
.containsKey(record1.getId()));
assertTrue(faultTolerancyBuffer.getRecordTimestamps().containsKey(
record1.getId()));
assertTrue(faultTolerancyBuffer.getRecordsByTime().get(record1TS)
.contains(record1.getId()));
String prevID = record1.getId();
faultTolerancyBuffer.failRecord(record1.getId());
Long record2TS=faultTolerancyBuffer.getRecordTimestamps().get(record1.getId());
assertFalse(faultTolerancyBuffer.getRecordBuffer().containsKey(prevID));
Long record2TS = faultTolerancyBuffer.getRecordTimestamps().get(
record1.getId());
assertFalse(faultTolerancyBuffer.getRecordBuffer().containsKey(prevID));
assertFalse(faultTolerancyBuffer.getAckCounter().containsKey(prevID));
assertFalse(faultTolerancyBuffer.getRecordTimestamps().containsKey(prevID));
assertFalse(faultTolerancyBuffer.getRecordsByTime().get(record1TS).contains(prevID));
assertFalse(faultTolerancyBuffer.getRecordsByTime().get(record1TS)
.contains(prevID));
faultTolerancyBuffer.ackRecord(prevID);
faultTolerancyBuffer.ackRecord(prevID);
faultTolerancyBuffer.ackRecord(prevID);
assertTrue(faultTolerancyBuffer.getRecordBuffer().containsKey(record1.getId()));
assertTrue(faultTolerancyBuffer.getAckCounter().containsKey(record1.getId()));
assertTrue(faultTolerancyBuffer.getRecordTimestamps().containsKey(record1.getId()));
assertTrue(faultTolerancyBuffer.getRecordsByTime().get(record2TS).contains(record1.getId()));
assertTrue(faultTolerancyBuffer.getRecordBuffer().containsKey(
record1.getId()));
assertTrue(faultTolerancyBuffer.getAckCounter()
.containsKey(record1.getId()));
assertTrue(faultTolerancyBuffer.getRecordTimestamps().containsKey(
record1.getId()));
assertTrue(faultTolerancyBuffer.getRecordsByTime().get(record2TS)
.contains(record1.getId()));
System.out.println("---------");
}
//TODO: create more tests for this method
// TODO: create more tests for this method
@Test
public void testTimeOutRecords() {
faultTolerancyBuffer.setTIMEOUT(1000);
StreamRecord record1 = (new StreamRecord(1)).setId("1");
record1.addRecord(new StringValue("V1"));
StreamRecord record2 = (new StreamRecord(1)).setId("1");
record2.addRecord(new StringValue("V2"));
StreamRecord record3 = (new StreamRecord(1)).setId("1");
record3.addRecord(new StringValue("V3"));
faultTolerancyBuffer.addRecord(record1);
faultTolerancyBuffer.addRecord(record2);
try {
Thread.sleep(500);
} catch (Exception e) {
}
faultTolerancyBuffer.addRecord(record3);
Long record1TS=faultTolerancyBuffer.getRecordTimestamps().get(record1.getId());
Long record2TS=faultTolerancyBuffer.getRecordTimestamps().get(record2.getId());
Long record1TS = faultTolerancyBuffer.getRecordTimestamps().get(
record1.getId());
Long record2TS = faultTolerancyBuffer.getRecordTimestamps().get(
record2.getId());
faultTolerancyBuffer.ackRecord(record1.getId());
faultTolerancyBuffer.ackRecord(record1.getId());
faultTolerancyBuffer.ackRecord(record1.getId());
faultTolerancyBuffer.ackRecord(record2.getId());
faultTolerancyBuffer.ackRecord(record3.getId());
faultTolerancyBuffer.ackRecord(record3.getId());
try {
Thread.sleep(501);
} catch (InterruptedException e) {
}
List<String> timedOutRecords = faultTolerancyBuffer.timeoutRecords(System.currentTimeMillis());
System.out.println("timedOutRecords: "+ timedOutRecords);
List<String> timedOutRecords = faultTolerancyBuffer.timeoutRecords(System
.currentTimeMillis());
System.out.println("timedOutRecords: " + timedOutRecords);
assertEquals(1, timedOutRecords.size());
assertFalse(timedOutRecords.contains(record1.getId()));
assertFalse(timedOutRecords.contains(record1.getId()));
assertFalse(faultTolerancyBuffer.getRecordsByTime().containsKey(record1TS));
assertFalse(faultTolerancyBuffer.getRecordsByTime().containsKey(record2TS));
assertTrue(faultTolerancyBuffer.getRecordBuffer().containsKey(record2.getId()));
assertTrue(faultTolerancyBuffer.getAckCounter().containsKey(record2.getId()));
assertTrue(faultTolerancyBuffer.getRecordTimestamps().containsKey(record2.getId()));
assertTrue(faultTolerancyBuffer.getRecordBuffer().containsKey(
record2.getId()));
assertTrue(faultTolerancyBuffer.getAckCounter()
.containsKey(record2.getId()));
assertTrue(faultTolerancyBuffer.getRecordTimestamps().containsKey(
record2.getId()));
System.out.println(faultTolerancyBuffer.getAckCounter());
try {
Thread.sleep(100);
} catch (InterruptedException e) {
}
timedOutRecords = faultTolerancyBuffer.timeoutRecords(System.currentTimeMillis());
assertEquals(null,timedOutRecords);
timedOutRecords = faultTolerancyBuffer.timeoutRecords(System
.currentTimeMillis());
assertEquals(null, timedOutRecords);
try {
Thread.sleep(901);
} catch (InterruptedException e) {
}
timedOutRecords = faultTolerancyBuffer.timeoutRecords(System.currentTimeMillis());
System.out.println("timedOutRecords: "+ timedOutRecords);
timedOutRecords = faultTolerancyBuffer.timeoutRecords(System
.currentTimeMillis());
System.out.println("timedOutRecords: " + timedOutRecords);
assertEquals(2, timedOutRecords.size());
System.out.println(faultTolerancyBuffer.getAckCounter());
System.out.println("---------");
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册