Commit ab14f901 authored by Martin Junghanns, committed by Till Rohrmann

[FLINK-2590] fixing DataSetUtils.zipWithUniqueId() and DataSetUtils.zipWithIndex()

* modified algorithm as explained in the issue
* updated method documentation

[FLINK-2590] reducing required bit shift size

* maximum bit shift size is now computed from getNumberOfParallelSubtasks() - 1

This closes #1075.
Parent 8c852c2a
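For context, a quick illustration of the second commit ("reducing required bit shift size"): with p parallel tasks the task ids are 0 .. p-1, so the shift width only needs to cover p - 1, not p. A minimal sketch — the `parallelism` value is hypothetical, and the `else` branch of `getBitSize` is an assumed completion, since the diff below truncates it:

```java
public class BitSizeExample {
	// mirror of the getBitSize helper introduced in this patch;
	// the else branch is assumed, as the diff below cuts it off
	static int getBitSize(long value) {
		if (value > Integer.MAX_VALUE) {
			return 64 - Integer.numberOfLeadingZeros((int) (value >> 32));
		} else {
			return 32 - Integer.numberOfLeadingZeros((int) value);
		}
	}

	public static void main(String[] args) {
		int parallelism = 8; // hypothetical degree of parallelism
		// old: shift by the bit size of p     -> 4 bits for task ids 0..7
		System.out.println(getBitSize(parallelism));
		// new: shift by the bit size of p - 1 -> 3 bits suffice for task ids 0..7
		System.out.println(getBitSize(parallelism - 1));
	}
}
```

This leaves one more bit of the 64-bit label space to the per-task counter.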
org/apache/flink/api/java/utils/DataSetUtils.java

@@ -18,8 +18,9 @@
 package org.apache.flink.api.java.utils;
 
+import com.google.common.collect.Lists;
+import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
 import org.apache.flink.api.common.functions.RichMapPartitionFunction;
-import org.apache.flink.api.java.sampling.IntermediateSampleData;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.Utils;
 import org.apache.flink.api.java.functions.SampleInCoordinator;
@@ -27,6 +28,7 @@ import org.apache.flink.api.java.functions.SampleInPartition;
 import org.apache.flink.api.java.functions.SampleWithFraction;
 import org.apache.flink.api.java.operators.GroupReduceOperator;
 import org.apache.flink.api.java.operators.MapPartitionOperator;
+import org.apache.flink.api.java.sampling.IntermediateSampleData;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.util.Collector;
@@ -49,11 +51,11 @@ public class DataSetUtils {
 	 * @return a data set containing tuples of subtask index, number of elements mappings.
 	 */
 	private static <T> DataSet<Tuple2<Integer, Long>> countElements(DataSet<T> input) {
-		return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer,Long>>() {
+		return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
 			@Override
 			public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
 				long counter = 0;
-				for(T value: values) {
+				for (T value : values) {
 					counter++;
 				}
@@ -63,8 +65,8 @@ public class DataSetUtils {
 	}
 
 	/**
-	 * Method that takes a set of subtask index, total number of elements mappings
-	 * and assigns ids to all the elements from the input data set.
+	 * Method that assigns a unique {@link Long} value to all elements in the input data set. The generated values are
+	 * consecutive.
 	 *
 	 * @param input the input data set
 	 * @return a data set of tuple 2 consisting of consecutive ids and initial values.
@@ -77,28 +79,36 @@ public class DataSetUtils {
 			long start = 0;
 
-			// compute the offset for each partition
 			@Override
 			public void open(Configuration parameters) throws Exception {
 				super.open(parameters);
-				List<Tuple2<Integer, Long>> offsets = getRuntimeContext().getBroadcastVariable("counts");
-
-				Collections.sort(offsets, new Comparator<Tuple2<Integer, Long>>() {
-					@Override
-					public int compare(Tuple2<Integer, Long> o1, Tuple2<Integer, Long> o2) {
-						return compareInts(o1.f0, o2.f0);
-					}
-				});
-
-				for(int i = 0; i < getRuntimeContext().getIndexOfThisSubtask(); i++) {
+				List<Tuple2<Integer, Long>> offsets = getRuntimeContext().getBroadcastVariableWithInitializer(
+						"counts",
+						new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
+							@Override
+							public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> data) {
+								// sort the list by task id to calculate the correct offset
+								List<Tuple2<Integer, Long>> sortedData = Lists.newArrayList(data);
+								Collections.sort(sortedData, new Comparator<Tuple2<Integer, Long>>() {
+									@Override
+									public int compare(Tuple2<Integer, Long> o1, Tuple2<Integer, Long> o2) {
+										return o1.f0.compareTo(o2.f0);
+									}
+								});
+								return sortedData;
+							}
+						});
+
+				// compute the offset for each partition
+				for (int i = 0; i < getRuntimeContext().getIndexOfThisSubtask(); i++) {
 					start += offsets.get(i).f1;
 				}
 			}
 
 			@Override
 			public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
-				for(T value: values) {
+				for (T value: values) {
 					out.collect(new Tuple2<Long, T>(start++, value));
 				}
 			}
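The rewritten `open()` above does two things: it sorts the broadcast (subtask index, element count) tuples inside a `BroadcastVariableInitializer`, so the list is materialized and sorted once when the broadcast variable is first accessed rather than in every subtask, and it then sums the counts of all lower-indexed subtasks to get this partition's start offset. A standalone sketch of that offset calculation, with hypothetical counts:

```java
import java.util.Arrays;
import java.util.List;

public class OffsetSketch {
	public static void main(String[] args) {
		// hypothetical element counts per subtask, sorted by subtask index
		List<Long> counts = Arrays.asList(3L, 5L, 2L);

		// each subtask starts after all elements of lower-indexed subtasks
		for (int subtask = 0; subtask < counts.size(); subtask++) {
			long start = 0;
			for (int i = 0; i < subtask; i++) {
				start += counts.get(i);
			}
			System.out.println("subtask " + subtask + " starts at " + start);
		}
		// prints 0, 3, 8 -> the assigned indices are globally consecutive
	}
}
```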
@@ -106,12 +116,13 @@ public class DataSetUtils {
 	}
 
 	/**
-	 * Method that assigns unique Long labels to all the elements in the input data set by making use of the
-	 * following abstractions:
+	 * Method that assigns a unique {@link Long} value to all elements in the input data set in the following way:
 	 * <ul>
-	 *  <li> a map function generates an n-bit (n - number of parallel tasks) ID based on its own index
-	 *  <li> with each record, a counter c is increased
-	 *  <li> the unique label is then produced by shifting the counter c by the n-bit mapper ID
+	 *  <li> a map function is applied to the input data set
+	 *  <li> each map task holds a counter c which is increased for each record
+	 *  <li> c is shifted by n bits where n = log2(number of parallel tasks)
+	 *  <li> to create a unique ID among all tasks, the task id is added to the counter
+	 *  <li> for each record, the resulting counter is collected
 	 * </ul>
 	 *
 	 * @param input the input data set
@@ -121,6 +132,7 @@ public class DataSetUtils {
 		return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {
 
+			long maxBitSize = getBitSize(Long.MAX_VALUE);
 			long shifter = 0;
 			long start = 0;
 			long taskId = 0;
@@ -129,16 +141,16 @@ public class DataSetUtils {
 			@Override
 			public void open(Configuration parameters) throws Exception {
 				super.open(parameters);
-				shifter = log2(getRuntimeContext().getNumberOfParallelSubtasks());
+				shifter = getBitSize(getRuntimeContext().getNumberOfParallelSubtasks() - 1);
 				taskId = getRuntimeContext().getIndexOfThisSubtask();
 			}
 
 			@Override
 			public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
-				for(T value: values) {
-					label = start << shifter + taskId;
+				for (T value : values) {
+					label = (start << shifter) + taskId;
 
-					if(log2(start) + shifter < log2(Long.MAX_VALUE)) {
+					if (getBitSize(start) + shifter < maxBitSize) {
 						out.collect(new Tuple2<Long, T>(label, value));
 						start++;
 					} else {
@@ -241,11 +253,7 @@ public class DataSetUtils {
 	// UTIL METHODS
 	// *************************************************************************
 
-	private static int compareInts(int x, int y) {
-		return (x < y) ? -1 : ((x == y) ? 0 : 1);
-	}
-
-	private static int log2(long value){
+	public static int getBitSize(long value){
 		if(value > Integer.MAX_VALUE) {
 			return 64 - Integer.numberOfLeadingZeros((int)(value >> 32));
 		} else {
...
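The key zipWithUniqueId fix in the hunks above is operator precedence: `+` binds tighter than `<<` in Java, so the old `start << shifter + taskId` actually computed `start << (shifter + taskId)`, producing colliding and quickly overflowing labels. A small sketch with hypothetical values:

```java
public class PrecedenceSketch {
	public static void main(String[] args) {
		long start = 1, shifter = 3, taskId = 2;

		// old (buggy): parsed as start << (shifter + taskId)
		long buggy = start << shifter + taskId;   // 1 << 5 = 32
		// new (fixed): shift first, then add the task id into the low bits
		long fixed = (start << shifter) + taskId; // 8 + 2  = 10

		System.out.println(buggy + " vs " + fixed); // 32 vs 10
	}
}
```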
org/apache/flink/test/util/DataSetUtilsITCase.java

@@ -18,66 +18,59 @@
 package org.apache.flink.test.util;
 
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.utils.DataSetUtils;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Rule;
+import org.junit.Assert;
 import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Set;
+
 @RunWith(Parameterized.class)
 public class DataSetUtilsITCase extends MultipleProgramsTestBase {
 
-	private String resultPath;
-	private String expectedResult;
-
-	@Rule
-	public TemporaryFolder tempFolder = new TemporaryFolder();
-
 	public DataSetUtilsITCase(TestExecutionMode mode) {
 		super(mode);
 	}
 
-	@Before
-	public void before() throws Exception{
-		resultPath = tempFolder.newFile().toURI().toString();
-	}
-
 	@Test
 	public void testZipWithIndex() throws Exception {
 		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
-		env.setParallelism(1);
-		DataSet<String> in = env.fromElements("A", "B", "C", "D", "E", "F");
-		DataSet<Tuple2<Long, String>> result = DataSetUtils.zipWithIndex(in);
-		result.writeAsCsv(resultPath, "\n", ",");
-		env.execute();
-
-		expectedResult = "0,A\n" + "1,B\n" + "2,C\n" + "3,D\n" + "4,E\n" + "5,F";
+		long expectedSize = 100L;
+		DataSet<Long> numbers = env.generateSequence(0, expectedSize - 1);
+
+		List<Tuple2<Long, Long>> result = Lists.newArrayList(DataSetUtils.zipWithIndex(numbers).collect());
+
+		Assert.assertEquals(expectedSize, result.size());
+		// sort result by created index
+		Collections.sort(result, new Comparator<Tuple2<Long, Long>>() {
+			@Override
+			public int compare(Tuple2<Long, Long> o1, Tuple2<Long, Long> o2) {
+				return o1.f0.compareTo(o2.f0);
+			}
+		});
+		// test if index is consecutive
+		for (int i = 0; i < expectedSize; i++) {
+			Assert.assertEquals(i, (long) result.get(i).f0);
+		}
 	}
 
 	@Test
 	public void testZipWithUniqueId() throws Exception {
 		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
-		env.setParallelism(1);
-		DataSet<String> in = env.fromElements("A", "B", "C", "D", "E", "F");
-		DataSet<Tuple2<Long, String>> result = DataSetUtils.zipWithUniqueId(in);
-		result.writeAsCsv(resultPath, "\n", ",");
-		env.execute();
-
-		expectedResult = "0,A\n" + "2,B\n" + "4,C\n" + "6,D\n" + "8,E\n" + "10,F";
-	}
-
-	@After
-	public void after() throws Exception{
-		compareResultsByLinesInMemory(expectedResult, resultPath);
+		long expectedSize = 100L;
+		DataSet<Long> numbers = env.generateSequence(1L, expectedSize);
+
+		Set<Tuple2<Long, Long>> result = Sets.newHashSet(DataSetUtils.zipWithUniqueId(numbers).collect());
+
+		Assert.assertEquals(expectedSize, result.size());
 	}
 }
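For reference, a minimal usage sketch of the two utilities mirroring what the rewritten tests assert (the class and input values here are illustrative, not part of the patch). `zipWithIndex` guarantees consecutive ids starting at 0; `zipWithUniqueId` only guarantees unique ones:

```java
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.utils.DataSetUtils;

public class ZipSketch {
	public static void main(String[] args) throws Exception {
		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
		DataSet<String> in = env.fromElements("A", "B", "C");

		// consecutive ids 0..n-1, independent of parallelism
		System.out.println(DataSetUtils.zipWithIndex(in).collect());

		// unique but not necessarily consecutive ids
		System.out.println(DataSetUtils.zipWithUniqueId(in).collect());
	}
}
```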