From ab14f90142fd69426bb695cbdb641f0a5a0c46f7 Mon Sep 17 00:00:00 2001
From: Martin Junghanns
Date: Sat, 29 Aug 2015 22:51:19 +0200
Subject: [PATCH] [FLINK-2590] fixing DataSetUtils.zipWithUniqueId() and DataSetUtils.zipWithIndex()

* modified algorithm as explained in the issue
* updated method documentation

[FLINK-2590] reducing required bit shift size

* maximum bit size is changed to getNumberOfParallelSubTasks() - 1

This closes #1075.
---
 .../flink/api/java/utils/DataSetUtils.java    | 70 +++++++++++--------
 .../flink/test/util/DataSetUtilsITCase.java   | 65 ++++++++---------
 2 files changed, 68 insertions(+), 67 deletions(-)

diff --git a/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java b/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java
index d2689259257..722fc6b3f97 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/utils/DataSetUtils.java
@@ -18,8 +18,9 @@
 
 package org.apache.flink.api.java.utils;
 
+import com.google.common.collect.Lists;
+import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
 import org.apache.flink.api.common.functions.RichMapPartitionFunction;
-import org.apache.flink.api.java.sampling.IntermediateSampleData;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.Utils;
 import org.apache.flink.api.java.functions.SampleInCoordinator;
@@ -27,6 +28,7 @@ import org.apache.flink.api.java.functions.SampleInPartition;
 import org.apache.flink.api.java.functions.SampleWithFraction;
 import org.apache.flink.api.java.operators.GroupReduceOperator;
 import org.apache.flink.api.java.operators.MapPartitionOperator;
+import org.apache.flink.api.java.sampling.IntermediateSampleData;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.util.Collector;
@@ -49,11 +51,11 @@ public class DataSetUtils {
 	 * @return a data set containing tuples of subtask index, number of elements mappings.
 	 */
 	private static <T> DataSet<Tuple2<Integer, Long>> countElements(DataSet<T> input) {
-		return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer,Long>>() {
+		return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
 			@Override
 			public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
 				long counter = 0;
-				for(T value: values) {
+				for (T value : values) {
 					counter++;
 				}
 
@@ -63,8 +65,8 @@ public class DataSetUtils {
 	}
 
 	/**
-	 * Method that takes a set of subtask index, total number of elements mappings
-	 * and assigns ids to all the elements from the input data set.
+	 * Method that assigns a unique {@link Long} value to all elements in the input data set. The generated values are
+	 * consecutive.
 	 *
 	 * @param input the input data set
 	 * @return a data set of tuple 2 consisting of consecutive ids and initial values.
@@ -77,28 +79,36 @@ public class DataSetUtils {
 
 			long start = 0;
 
-			// compute the offset for each partition
 			@Override
 			public void open(Configuration parameters) throws Exception {
 				super.open(parameters);
-				List<Tuple2<Integer, Long>> offsets = getRuntimeContext().getBroadcastVariable("counts");
-
-				Collections.sort(offsets, new Comparator<Tuple2<Integer, Long>>() {
-					@Override
-					public int compare(Tuple2<Integer, Long> o1, Tuple2<Integer, Long> o2) {
-						return compareInts(o1.f0, o2.f0);
-					}
-				});
-
-				for(int i = 0; i < getRuntimeContext().getIndexOfThisSubtask(); i++) {
+				List<Tuple2<Integer, Long>> offsets = getRuntimeContext().getBroadcastVariableWithInitializer(
+						"counts",
+						new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
+							@Override
+							public List<Tuple2<Integer, Long>> initializeBroadcastVariable(Iterable<Tuple2<Integer, Long>> data) {
+								// sort the list by task id to calculate the correct offset
+								List<Tuple2<Integer, Long>> sortedData = Lists.newArrayList(data);
+								Collections.sort(sortedData, new Comparator<Tuple2<Integer, Long>>() {
+									@Override
+									public int compare(Tuple2<Integer, Long> o1, Tuple2<Integer, Long> o2) {
+										return o1.f0.compareTo(o2.f0);
+									}
+								});
+								return sortedData;
+							}
+						});
+
+				// compute the offset for each partition
+				for (int i = 0; i < getRuntimeContext().getIndexOfThisSubtask(); i++) {
 					start += offsets.get(i).f1;
 				}
 			}
 
 			@Override
 			public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
-				for(T value: values) {
+				for (T value: values) {
 					out.collect(new Tuple2<Long, T>(start++, value));
 				}
 			}
@@ -106,12 +116,13 @@ public class DataSetUtils {
 	}
 
 	/**
-	 * Method that assigns unique Long labels to all the elements in the input data set by making use of the
-	 * following abstractions:
+	 * Method that assigns a unique {@link Long} value to all elements in the input data set in the following way:
 	 * <ul>
-	 *  <li> a map function generates an n-bit (n - number of parallel tasks) ID based on its own index
-	 *  <li> with each record, a counter c is increased
-	 *  <li> the unique label is then produced by shifting the counter c by the n-bit mapper ID
+	 *  <li> a map function is applied to the input data set
+	 *  <li> each map task holds a counter c which is increased for each record
+	 *  <li> c is shifted by n bits where n = log2(number of parallel tasks)
+	 *  <li> to create a unique ID among all tasks, the task id is added to the counter
+	 *  <li> for each record, the resulting counter is collected
 	 * </ul>
 	 *
 	 * @param input the input data set
@@ -121,6 +132,7 @@ public class DataSetUtils {
 
 		return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Long, T>>() {
 
+			long maxBitSize = getBitSize(Long.MAX_VALUE);
 			long shifter = 0;
 			long start = 0;
 			long taskId = 0;
@@ -129,16 +141,16 @@ public class DataSetUtils {
 			@Override
 			public void open(Configuration parameters) throws Exception {
 				super.open(parameters);
-				shifter = log2(getRuntimeContext().getNumberOfParallelSubtasks());
+				shifter = getBitSize(getRuntimeContext().getNumberOfParallelSubtasks() - 1);
 				taskId = getRuntimeContext().getIndexOfThisSubtask();
 			}
 
 			@Override
 			public void mapPartition(Iterable<T> values, Collector<Tuple2<Long, T>> out) throws Exception {
-				for(T value: values) {
-					label = start << shifter + taskId;
+				for (T value : values) {
+					label = (start << shifter) + taskId;
 
-					if(log2(start) + shifter < log2(Long.MAX_VALUE)) {
+					if (getBitSize(start) + shifter < maxBitSize) {
 						out.collect(new Tuple2<Long, T>(label, value));
 						start++;
 					} else {
@@ -241,11 +253,7 @@ public class DataSetUtils {
 	// UTIL METHODS
 	// *************************************************************************
 
-	private static int compareInts(int x, int y) {
-		return (x < y) ? -1 : ((x == y) ? 0 : 1);
-	}
-
-	private static int log2(long value){
+	public static int getBitSize(long value){
 		if(value > Integer.MAX_VALUE) {
 			return 64 - Integer.numberOfLeadingZeros((int)(value >> 32));
 		} else {

diff --git a/flink-tests/src/test/java/org/apache/flink/test/util/DataSetUtilsITCase.java b/flink-tests/src/test/java/org/apache/flink/test/util/DataSetUtilsITCase.java
index 1e5363b47b7..a28911659d5 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/util/DataSetUtilsITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/util/DataSetUtilsITCase.java
@@ -18,66 +18,59 @@
 
 package org.apache.flink.test.util;
 
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.utils.DataSetUtils;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Rule;
+import org.junit.Assert;
 import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Set;
+
 @RunWith(Parameterized.class)
 public class DataSetUtilsITCase extends MultipleProgramsTestBase {
 
-	private String resultPath;
-	private String expectedResult;
-
-	@Rule
-	public TemporaryFolder tempFolder = new TemporaryFolder();
-
 	public DataSetUtilsITCase(TestExecutionMode mode) {
 		super(mode);
 	}
 
-	@Before
-	public void before() throws Exception{
-		resultPath = tempFolder.newFile().toURI().toString();
-	}
-
 	@Test
 	public void testZipWithIndex() throws Exception {
 		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
-		env.setParallelism(1);
-		DataSet<String> in = env.fromElements("A", "B", "C", "D", "E", "F");
-
-		DataSet<Tuple2<Long, String>> result = DataSetUtils.zipWithIndex(in);
-
-		result.writeAsCsv(resultPath, "\n", ",");
-		env.execute();
-
-		expectedResult = "0,A\n" + "1,B\n" + "2,C\n" + "3,D\n" + "4,E\n" + "5,F";
+		long expectedSize = 100L;
+		DataSet<Long> numbers = env.generateSequence(0, expectedSize - 1);
+
+		List<Tuple2<Long, Long>> result = Lists.newArrayList(DataSetUtils.zipWithIndex(numbers).collect());
+
+		Assert.assertEquals(expectedSize, result.size());
+		// sort result by created index
+		Collections.sort(result, new Comparator<Tuple2<Long, Long>>() {
+			@Override
+			public int compare(Tuple2<Long, Long> o1, Tuple2<Long, Long> o2) {
+				return o1.f0.compareTo(o2.f0);
+			}
+		});
+		// test if index is consecutive
+		for (int i = 0; i < expectedSize; i++) {
+			Assert.assertEquals(i, (long) result.get(i).f0);
+		}
 	}
 
 	@Test
 	public void testZipWithUniqueId() throws Exception {
 		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
-		env.setParallelism(1);
-		DataSet<String> in = env.fromElements("A", "B", "C", "D", "E", "F");
-
-		DataSet<Tuple2<Long, String>> result = DataSetUtils.zipWithUniqueId(in);
+		long expectedSize = 100L;
+		DataSet<Long> numbers = env.generateSequence(1L, expectedSize);
 
-		result.writeAsCsv(resultPath, "\n", ",");
-		env.execute();
-
-		expectedResult = "0,A\n" + "2,B\n" + "4,C\n" + "6,D\n" + "8,E\n" + "10,F";
-	}
+		Set<Tuple2<Long, Long>> result = Sets.newHashSet(DataSetUtils.zipWithUniqueId(numbers).collect());
 
-	@After
-	public void after() throws Exception{
-		compareResultsByLinesInMemory(expectedResult, resultPath);
+		Assert.assertEquals(expectedSize, result.size());
 	}
 }
-- 
GitLab
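
A note on the zipWithIndex change: the broadcast "counts" variable delivers the per-subtask
(subtask id, element count) pairs in no particular order, so the patch sorts them by subtask id
inside a BroadcastVariableInitializer before each subtask sums the counts of all lower-numbered
subtasks into its local starting offset. The following standalone sketch replays that offset
computation outside of Flink; the class name and the sample counts are made up for illustration.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public class ZipWithIndexOffsetSketch {

	public static void main(String[] args) {
		// (subtask id, element count) pairs, unordered, as a broadcast
		// variable would deliver them
		List<long[]> counts = new ArrayList<>(Arrays.asList(
				new long[]{2, 5},    // subtask 2 counted 5 elements
				new long[]{0, 3},    // subtask 0 counted 3 elements
				new long[]{1, 4}));  // subtask 1 counted 4 elements

		// sort by subtask id, as the BroadcastVariableInitializer in the patch does
		counts.sort(Comparator.comparingLong(c -> c[0]));

		// each subtask starts numbering after all lower-numbered subtasks
		long start = 0;
		for (long[] c : counts) {
			System.out.printf("subtask %d assigns indexes %d..%d%n",
					c[0], start, start + c[1] - 1);
			start += c[1];
		}
		// prints 0..2, 3..6, 7..11: consecutive ids across all subtasks
	}
}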
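
The zipWithUniqueId change fixes two independent problems. First, Java's '+' binds tighter than
'<<', so the old expression start << shifter + taskId actually shifted by (shifter + taskId) and
produced colliding labels across tasks. Second, the shift width is now derived from the largest
possible task id, getNumberOfParallelSubtasks() - 1, rather than from the parallelism itself,
which saves one bit whenever the parallelism is a power of two. Below is a standalone sketch of
the patched scheme; the class name is illustrative, and getBitSize() mirrors the helper the patch
makes public.

public class ZipWithUniqueIdSketch {

	// mirrors DataSetUtils.getBitSize() from the patch
	static int getBitSize(long value) {
		if (value > Integer.MAX_VALUE) {
			return 64 - Integer.numberOfLeadingZeros((int) (value >> 32));
		} else {
			return 32 - Integer.numberOfLeadingZeros((int) value);
		}
	}

	public static void main(String[] args) {
		int parallelism = 4;
		// task ids range over 0..parallelism-1, so getBitSize(parallelism - 1)
		// bits suffice: 2 bits for 4 tasks instead of the 3 the old log2(4) returned
		long shifter = getBitSize(parallelism - 1);

		for (long taskId = 0; taskId < parallelism; taskId++) {
			for (long start = 0; start < 3; start++) {
				// patched expression: shift first, then add the task id
				long fixed = (start << shifter) + taskId;
				// old expression: '+' bound tighter than '<<', so it shifted
				// by (shifter + taskId) and different tasks could collide
				long buggy = start << (shifter + taskId);
				System.out.printf("task %d, record %d -> fixed %d, buggy %d%n",
						taskId, start, fixed, buggy);
			}
		}
		// fixed labels are pairwise distinct (task t gets t, t+4, t+8, ...);
		// buggy labels collide, e.g. record 0 maps to label 0 for every task
	}
}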