Commit ab14f901 authored by Martin Junghanns, committed by Till Rohrmann

[FLINK-2590] fixing DataSetUtils.zipWithUniqueId() and DataSetUtils.zipWithIndex()

* modified algorithm as explained in the issue
* updated method documentation

[FLINK-2590] reducing required bit shift size

* maximum bit size is now computed from getNumberOfParallelSubtasks() - 1

This closes #1075.
Parent 8c852c2a
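For context, the heart of the fix is the label computation in zipWithUniqueId(). The following standalone sketch is illustrative only (the class and method names are invented, not part of the patch): it reproduces the fixed arithmetic and shows why the old expression start << shifter + taskId was broken, since + binds tighter than << in Java.

// Illustrative sketch, not part of this commit.
public final class UniqueIdSketch {

    // Mirrors DataSetUtils.getBitSize: the bit length of a non-negative long.
    // The lower branch is assumed here; the diff below truncates it.
    static int getBitSize(long value) {
        if (value > Integer.MAX_VALUE) {
            return 64 - Integer.numberOfLeadingZeros((int) (value >> 32));
        } else {
            return 32 - Integer.numberOfLeadingZeros((int) value);
        }
    }

    // The per-record label after the fix. The parentheses matter: '+' binds
    // tighter than '<<', so the old code computed start << (shifter + taskId).
    static long label(long counter, int taskId, int parallelism) {
        // bits needed for the highest task id (0 .. parallelism - 1)
        int shifter = getBitSize(parallelism - 1);
        return (counter << shifter) + taskId;
    }

    public static void main(String[] args) {
        // With 4 parallel tasks, shifter = getBitSize(3) = 2,
        // so task 1 emits the labels 1, 5, 9, ...
        for (long c = 0; c < 3; c++) {
            System.out.println(label(c, 1, 4));
        }
    }
}

Shifting by getBitSize(parallelism - 1) rather than by the bit length of parallelism itself saves one bit whenever the parallelism is a power of two, since task ids only range over 0 .. parallelism - 1.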
@@ -18,8 +18,9 @@
package org.apache.flink.api.java.utils;
+ import com.google.common.collect.Lists;
+ import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
- import org.apache.flink.api.java.sampling.IntermediateSampleData;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.Utils;
import org.apache.flink.api.java.functions.SampleInCoordinator;
@@ -27,6 +28,7 @@ import org.apache.flink.api.java.functions.SampleInPartition;
import org.apache.flink.api.java.functions.SampleWithFraction;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.MapPartitionOperator;
+ import org.apache.flink.api.java.sampling.IntermediateSampleData;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;
@@ -49,11 +51,11 @@ public class DataSetUtils {
* @return a data set containing tuples of subtask index, number of elements mappings.
*/
private static <T> DataSet<Tuple2<Integer, Long>> countElements(DataSet<T> input) {
- return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer,Long>>() {
+ return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
@Override
public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
long counter = 0;
- for(T value: values) {
+ for (T value : values) {
counter++;
}
@@ -63,8 +65,8 @@ public class DataSetUtils {
}
/**
- * Method that takes a set of subtask index, total number of elements mappings
- * and assigns ids to all the elements from the input data set.
+ * Method that assigns a unique {@link Long} value to all elements in the input data set. The generated values are
+ * consecutive.
*
* @param input the input data set
* @return a data set of tuple 2 consisting of consecutive ids and initial values.
@@ -77,28 +79,36 @@
long start = 0;
// compute the offset for each partition
@Override
public void open(Configuration parameters) throws Exception {
super.open(parameters);
- List<Tuple2<Integer, Long>> offsets = getRuntimeContext().getBroadcastVariable("counts");
- Collections.sort(offsets, new Comparator<Tuple2<Integer, Long>>() {
- @Override
- public int compare(Tuple2<Integer, Long> o1, Tuple2<Integer, Long> o2) {
- return compareInts(o1.f0, o2.f0);
- }
- });
- for(int i = 0; i < getRuntimeContext().getIndexOfThisSubtask(); i++) {
+ List<Tuple2<Integer, Long>> offsets = getRuntimeContext().getBroadcastVariableWithInitializer(
+ "counts",
+ new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
+ @Override
+ public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> data) {
+ // sort the list by task id to calculate the correct offset
+ List<Tuple2<Integer, Long>> sortedData = Lists.newArrayList(data);
+ Collections.sort(sortedData, new Comparator<Tuple2<Integer, Long>>() {
+ @Override
+ public int compare(Tuple2<Integer, Long> o1, Tuple2<Integer, Long> o2) {
+ return o1.f0.compareTo(o2.f0);
+ }
+ });
+ return sortedData;
+ }
+ });
+ // compute the offset for each partition
+ for (int i = 0; i < getRuntimeContext().getIndexOfThisSubtask(); i++) {
start += offsets.get(i).f1;
}
}
@Override
public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
- for(T value: values) {
+ for (T value: values) {
out.collect(new Tuple2<Long, T>(start++, value));
}
}
@@ -106,12 +116,13 @@ public class DataSetUtils {
}
/**
- * Method that assigns unique Long labels to all the elements in the input data set by making use of the
- * following abstractions:
+ * Method that assigns a unique {@link Long} value to all elements in the input data set in the following way:
* <ul>
- * <li> a map function generates an n-bit (n - number of parallel tasks) ID based on its own index
- * <li> with each record, a counter c is increased
- * <li> the unique label is then produced by shifting the counter c by the n-bit mapper ID
+ * <li> a map function is applied to the input data set
+ * <li> each map task holds a counter c which is increased for each record
+ * <li> c is shifted by n bits where n = log2(number of parallel tasks)
+ * <li> to create a unique ID among all tasks, the task id is added to the counter
+ * <li> for each record, the resulting counter is collected
* </ul>
*
* @param input the input data set
@@ -121,6 +132,7 @@
return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {
+ long maxBitSize = getBitSize(Long.MAX_VALUE);
long shifter = 0;
long start = 0;
long taskId = 0;
@@ -129,16 +141,16 @@
@Override
public void open(Configuration parameters) throws Exception {
super.open(parameters);
- shifter = log2(getRuntimeContext().getNumberOfParallelSubtasks());
+ shifter = getBitSize(getRuntimeContext().getNumberOfParallelSubtasks() - 1);
taskId = getRuntimeContext().getIndexOfThisSubtask();
}
@Override
public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
- for(T value: values) {
- label = start << shifter + taskId;
+ for (T value : values) {
+ label = (start << shifter) + taskId;
- if(log2(start) + shifter < log2(Long.MAX_VALUE)) {
+ if (getBitSize(start) + shifter < maxBitSize) {
out.collect(new Tuple2<Long, T>(label, value));
start++;
} else {
@@ -241,11 +253,7 @@ public class DataSetUtils {
// UTIL METHODS
// *************************************************************************
- private static int compareInts(int x, int y) {
- return (x < y) ? -1 : ((x == y) ? 0 : 1);
- }
- private static int log2(long value){
+ public static int getBitSize(long value){
if(value > Integer.MAX_VALUE) {
return 64 - Integer.numberOfLeadingZeros((int)(value >> 32));
} else {
......
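Before the test changes, it may help to see the offset logic of zipWithIndex() in isolation. The sketch below uses plain Java and hypothetical names, no Flink APIs: like the broadcast-variable initializer above, it sorts the (subtaskIndex, elementCount) pairs by subtask index and prefix-sums the counts of all earlier subtasks.

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

// Standalone sketch, not part of this commit.
public final class OffsetSketch {

    // counts: each entry is {subtaskIndex, elementCount}, in arbitrary order
    static long startIndexFor(int thisSubtask, List<long[]> counts) {
        List<long[]> sorted = new ArrayList<>(counts);
        // sort by subtask index so the prefix sum is taken in task order
        Collections.sort(sorted, new Comparator<long[]>() {
            @Override
            public int compare(long[] a, long[] b) {
                return Long.compare(a[0], b[0]);
            }
        });
        long start = 0;
        for (int i = 0; i < thisSubtask; i++) {
            start += sorted.get(i)[1]; // elements owned by earlier subtasks
        }
        return start;
    }

    public static void main(String[] args) {
        List<long[]> counts = new ArrayList<>();
        counts.add(new long[] {1, 2});
        counts.add(new long[] {0, 3});
        counts.add(new long[] {2, 4});
        // subtask 2 starts after the 3 + 2 elements of subtasks 0 and 1
        System.out.println(startIndexFor(2, counts)); // prints 5
    }
}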
@@ -18,66 +18,59 @@
package org.apache.flink.test.util;
+ import com.google.common.collect.Lists;
+ import com.google.common.collect.Sets;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;
- import org.junit.After;
- import org.junit.Before;
- import org.junit.Rule;
+ import org.junit.Assert;
import org.junit.Test;
- import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
+ import java.util.Collections;
+ import java.util.Comparator;
+ import java.util.List;
+ import java.util.Set;
@RunWith(Parameterized.class)
public class DataSetUtilsITCase extends MultipleProgramsTestBase {
- private String resultPath;
- private String expectedResult;
- @Rule
- public TemporaryFolder tempFolder = new TemporaryFolder();
public DataSetUtilsITCase(TestExecutionMode mode) {
super(mode);
}
- @Before
- public void before() throws Exception{
- resultPath = tempFolder.newFile().toURI().toString();
- }
@Test
public void testZipWithIndex() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
- env.setParallelism(1);
- DataSet<String> in = env.fromElements("A", "B", "C", "D", "E", "F");
- DataSet<Tuple2<Long, String>> result = DataSetUtils.zipWithIndex(in);
- result.writeAsCsv(resultPath, "\n", ",");
- env.execute();
- expectedResult = "0,A\n" + "1,B\n" + "2,C\n" + "3,D\n" + "4,E\n" + "5,F";
+ long expectedSize = 100L;
+ DataSet<Long> numbers = env.generateSequence(0, expectedSize - 1);
+ List<Tuple2<Long, Long>> result = Lists.newArrayList(DataSetUtils.zipWithIndex(numbers).collect());
+ Assert.assertEquals(expectedSize, result.size());
+ // sort result by created index
+ Collections.sort(result, new Comparator<Tuple2<Long, Long>>() {
+ @Override
+ public int compare(Tuple2<Long, Long> o1, Tuple2<Long, Long> o2) {
+ return o1.f0.compareTo(o2.f0);
+ }
+ });
+ // test if index is consecutive
+ for (int i = 0; i < expectedSize; i++) {
+ Assert.assertEquals(i, (long) result.get(i).f0);
+ }
}
@Test
public void testZipWithUniqueId() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
- env.setParallelism(1);
- DataSet<String> in = env.fromElements("A", "B", "C", "D", "E", "F");
- DataSet<Tuple2<Long, String>> result = DataSetUtils.zipWithUniqueId(in);
+ long expectedSize = 100L;
+ DataSet<Long> numbers = env.generateSequence(1L, expectedSize);
- result.writeAsCsv(resultPath, "\n", ",");
- env.execute();
- expectedResult = "0,A\n" + "2,B\n" + "4,C\n" + "6,D\n" + "8,E\n" + "10,F";
}
+ Set<Tuple2<Long, Long>> result = Sets.newHashSet(DataSetUtils.zipWithUniqueId(numbers).collect());
- @After
- public void after() throws Exception{
- compareResultsByLinesInMemory(expectedResult, resultPath);
+ Assert.assertEquals(expectedSize, result.size());
}
}
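For completeness, here is a minimal usage sketch of the two utilities (a hypothetical driver class, mirroring the test setup above; collect() is used as in the tests):

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.utils.DataSetUtils;

// Hypothetical example, not part of this commit.
public class ZipExample {
    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        DataSet<String> in = env.fromElements("A", "B", "C", "D");

        // consecutive ids 0..3, independent of parallelism
        System.out.println(DataSetUtils.zipWithIndex(in).collect());

        // ids unique across all tasks, but not necessarily consecutive
        System.out.println(DataSetUtils.zipWithUniqueId(in).collect());
    }
}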