Commit 6a58aade authored by Till Rohrmann

[FLINK-2590] Fixes Scala's DataSetUtilsITCase

Parent: ab14f901
@@ -29,7 +29,7 @@ import _root_.scala.reflect.ClassTag
  * or with a unique identifier.
  */
-class DataSetUtils[T](val self: DataSet[T]) extends AnyVal {
+class DataSetUtils[T](val self: DataSet[T]) {
   /**
    * Method that takes a set of subtask index, total number of elements mappings
......
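The hunk above removes "extends AnyVal" from the Scala DataSetUtils wrapper class. For context, the sketch below shows how such a wrapper is usually attached to a DataSet through an implicit conversion, which is the same mechanism behind the utilsToDataSet import used by the Scala test further down in this diff. The names MyUtils, firstThree and toMyUtils are illustrative assumptions, not Flink definitions.

import org.apache.flink.api.scala._
import scala.language.implicitConversions

// Illustrative wrapper: adds an extra method to DataSet[T].
class MyUtils[T](val self: DataSet[T]) {
  // Hypothetical helper; the real DataSetUtils wrapper adds zipWithIndex / zipWithUniqueId.
  def firstThree(): DataSet[T] = self.first(3)
}

object MyUtils {
  // Implicit conversion that puts the wrapper's methods on any DataSet[T].
  implicit def toMyUtils[T](ds: DataSet[T]): MyUtils[T] = new MyUtils(ds)
}

With import MyUtils.toMyUtils in scope, someDataSet.firstThree() compiles because the compiler inserts the conversion; the tests below rely on the analogous utilsToDataSet conversion to call numbers.zipWithIndex directly.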
@@ -20,6 +20,7 @@ package org.apache.flink.test.util;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
+import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
 import org.apache.flink.api.java.tuple.Tuple2;
@@ -69,7 +70,14 @@ public class DataSetUtilsITCase extends MultipleProgramsTestBase {
     long expectedSize = 100L;
     DataSet<Long> numbers = env.generateSequence(1L, expectedSize);
-    Set<Tuple2<Long, Long>> result = Sets.newHashSet(DataSetUtils.zipWithUniqueId(numbers).collect());
+    DataSet<Long> ids = DataSetUtils.zipWithUniqueId(numbers).map(new MapFunction<Tuple2<Long,Long>, Long>() {
+      @Override
+      public Long map(Tuple2<Long, Long> value) throws Exception {
+        return value.f0;
+      }
+    });
+    Set<Long> result = Sets.newHashSet(ids.collect());
     Assert.assertEquals(expectedSize, result.size());
   }
......
@@ -19,63 +19,46 @@
 package org.apache.flink.api.scala.util
 
 import org.apache.flink.api.scala._
-import org.apache.flink.test.util.{MultipleProgramsTestBase, TestBaseUtils}
-import org.junit.rules.TemporaryFolder
+import org.apache.flink.test.util.{MultipleProgramsTestBase}
 import org.junit.runner.RunWith
 import org.junit.runners.Parameterized
-import org.junit.{After, Before, Rule, Test}
+import org.junit._
+import org.apache.flink.api.scala.DataSetUtils.utilsToDataSet
 
 @RunWith(classOf[Parameterized])
-class DataSetUtilsITCase (mode: MultipleProgramsTestBase.TestExecutionMode) extends
-  MultipleProgramsTestBase(mode){
-  private var resultPath: String = null
-  private var expectedResult: String = null
-  private val tempFolder: TemporaryFolder = new TemporaryFolder()
-  @Rule
-  def getFolder = tempFolder
-  @Before
-  @throws(classOf[Exception])
-  def before(): Unit = {
-    resultPath = tempFolder.newFile.toURI.toString
-  }
+class DataSetUtilsITCase (
+    mode: MultipleProgramsTestBase.TestExecutionMode)
+  extends MultipleProgramsTestBase(mode) {
 
   @Test
   @throws(classOf[Exception])
   def testZipWithIndex(): Unit = {
     val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
-    env.setParallelism(1)
-    val input: DataSet[String] = env.fromElements("A", "B", "C", "D", "E", "F")
-    val result: DataSet[(Long, String)] = input.zipWithIndex
-    result.writeAsCsv(resultPath, "\n", ",")
-    env.execute()
-    expectedResult = "0,A\n" + "1,B\n" + "2,C\n" + "3,D\n" + "4,E\n" + "5,F"
+    val expectedSize = 100L
+    val numbers = env.generateSequence(0, expectedSize - 1)
+    val result = numbers.zipWithIndex.collect()
+    Assert.assertEquals(expectedSize, result.size)
+    for( ((index, _), expected) <- result.sortBy(_._1).zipWithIndex) {
+      Assert.assertEquals(expected, index)
+    }
   }
 
   @Test
   @throws(classOf[Exception])
   def testZipWithUniqueId(): Unit = {
     val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
-    env.setParallelism(1)
-    val input: DataSet[String] = env.fromElements("A", "B", "C", "D", "E", "F")
-    val result: DataSet[(Long, String)] = input.zipWithUniqueId
-    result.writeAsCsv(resultPath, "\n", ",")
-    env.execute()
-    expectedResult = "0,A\n" + "2,B\n" + "4,C\n" + "6,D\n" + "8,E\n" + "10,F"
-  }
-  @After
-  @throws(classOf[Exception])
-  def after(): Unit = {
-    TestBaseUtils.compareResultsByLinesInMemory(expectedResult, resultPath)
+    val expectedSize = 100L
+    val numbers = env.generateSequence(1L, expectedSize)
+    val result = numbers.zipWithUniqueId.collect().map(_._1).toSet
+    Assert.assertEquals(expectedSize, result.size)
   }
 }
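Taken together, the rewritten tests no longer write CSV files and compare them in an @After hook; they collect() the zipped data set and assert on the result in memory. The sketch below condenses that pattern into a standalone Scala program. It assumes a local ExecutionEnvironment and uses the utilsToDataSet import that appears in this diff (the import path may differ in other Flink versions); the ZipChecks object and the plain assert calls are illustrative stand-ins for the JUnit test methods.

import org.apache.flink.api.scala._
import org.apache.flink.api.scala.DataSetUtils.utilsToDataSet

// Condensed, collect()-based version of the checks in the rewritten ITCase.
object ZipChecks {
  def main(args: Array[String]): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment
    val expectedSize = 100L

    // zipWithIndex: after sorting by index, the indices must be exactly 0 .. n-1.
    val indexed = env.generateSequence(0, expectedSize - 1).zipWithIndex.collect()
    assert(indexed.size == expectedSize)
    indexed.sortBy(_._1).map(_._1).zipWithIndex.foreach { case (index, expected) =>
      assert(index == expected)
    }

    // zipWithUniqueId: ids are not required to be consecutive, only unique,
    // so collecting the id field into a Set and checking its size is enough.
    val ids = env.generateSequence(1L, expectedSize).zipWithUniqueId.collect().map(_._1).toSet
    assert(ids.size == expectedSize)
  }
}

The same idea drives the Java change above: map each (id, value) pair to its id, collect the ids into a HashSet, and compare the set size against the expected element count.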