提交 c8c50be8 编写于 作者: A Aljoscha Krettek 提交者: Stephan Ewen

[FLINK-1378] [scala] Add support for Try[A] (Success/Failure)

This closes #293
上级 8af6ef49
......@@ -69,6 +69,8 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
case EitherType(leftTpe, rightTpe) => analyzeEither(id, tpe, leftTpe, rightTpe)
case TryType(elemTpe) => analyzeTry(id, tpe, elemTpe)
case OptionType(elemTpe) => analyzeOption(id, tpe, elemTpe)
case CaseClassType() => analyzeCaseClass(id, tpe)
......@@ -116,6 +118,14 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
}
}
private def analyzeTry(
id: Int,
tpe: Type,
elemTpe: Type): UDTDescriptor = analyze(elemTpe) match {
case UnsupportedDescriptor(_, _, errs) => UnsupportedDescriptor(id, tpe, errs)
case elemDesc => TryDescriptor(id, tpe, elemDesc)
}
private def analyzeOption(
id: Int,
tpe: Type,
......@@ -310,6 +320,20 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
}
}
private object TryType {
def unapply(tpe: Type): Option[Type] = {
if (tpe <:< typeOf[scala.util.Try[_]]) {
val option = tpe.baseType(typeOf[scala.util.Try[_]].typeSymbol)
option match {
case TypeRef(_, _, elemTpe :: Nil) =>
Some(elemTpe)
}
} else {
None
}
}
}
private object OptionType {
def unapply(tpe: Type): Option[Type] = {
if (tpe <:< typeOf[Option[_]]) {
......
......@@ -100,6 +100,13 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
override def canBeKey = false
}
case class TryDescriptor(id: Int, tpe: Type, elem: UDTDescriptor)
extends UDTDescriptor {
override val isPrimitiveProduct = false
override def flatten = Seq(this)
override def canBeKey = false
}
case class OptionDescriptor(id: Int, tpe: Type, elem: UDTDescriptor)
extends UDTDescriptor {
override val isPrimitiveProduct = false
......
......@@ -63,6 +63,8 @@ private[flink] trait TypeInformationGen[C <: Context] {
case e: EitherDescriptor => mkEitherTypeInfo(e)
case tr: TryDescriptor => mkTryTypeInfo(tr)
case o: OptionDescriptor => mkOptionTypeInfo(o)
case a : ArrayDescriptor => mkArrayTypeInfo(a)
......@@ -129,6 +131,19 @@ private[flink] trait TypeInformationGen[C <: Context] {
c.Expr[TypeInformation[T]](result)
}
def mkTryTypeInfo[T: c.WeakTypeTag](desc: TryDescriptor): c.Expr[TypeInformation[T]] = {
val elemTypeInfo = mkTypeInfo(desc.elem)(c.WeakTypeTag(desc.elem.tpe))
val result = q"""
import org.apache.flink.api.scala.typeutils.TryTypeInfo
new TryTypeInfo[${desc.elem.tpe}, ${desc.tpe}]($elemTypeInfo)
"""
c.Expr[TypeInformation[T]](result)
}
def mkOptionTypeInfo[T: c.WeakTypeTag](desc: OptionDescriptor): c.Expr[TypeInformation[T]] = {
val elemTypeInfo = mkTypeInfo(desc.elem)(c.WeakTypeTag(desc.elem.tpe))
......
......@@ -21,7 +21,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.common.typeutils.TypeSerializer
/**
* TypeInformation [[Option]].
* TypeInformation for [[Option]].
*/
class OptionTypeInfo[A, T <: Option[A]](elemTypeInfo: TypeInformation[A])
extends TypeInformation[T] {
......@@ -35,7 +35,7 @@ class OptionTypeInfo[A, T <: Option[A]](elemTypeInfo: TypeInformation[A])
def createSerializer(): TypeSerializer[T] = {
if (elemTypeInfo == null) {
// this happens when the type of a DataSet is None
// this happens when the type of a DataSet is None, i.e. DataSet[None]
new OptionSerializer(new NothingSerializer).asInstanceOf[TypeSerializer[T]]
} else {
new OptionSerializer(elemTypeInfo.createSerializer()).asInstanceOf[TypeSerializer[T]]
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala.typeutils
import org.apache.flink.api.common.typeutils.TypeSerializer
import org.apache.flink.api.java.typeutils.runtime.KryoSerializer
import org.apache.flink.core.memory.{DataInputView, DataOutputView}
import scala.util.{Success, Try, Failure}
/**
* Serializer for [[scala.util.Try]].
*/
class TrySerializer[A](val elemSerializer: TypeSerializer[A])
extends TypeSerializer[Try[A]] {
override def isStateful: Boolean = false
val throwableSerializer = new KryoSerializer[Throwable](classOf[Throwable])
override def createInstance: Try[A] = {
Failure(new RuntimeException("Empty Failure"))
}
override def isImmutableType: Boolean = elemSerializer == null || elemSerializer.isImmutableType
override def getLength: Int = -1
override def copy(from: Try[A]): Try[A] = from match {
case Success(a) => Success(elemSerializer.copy(a))
case Failure(t) => Failure(throwableSerializer.copy(t))
}
override def copy(from: Try[A], reuse: Try[A]): Try[A] = copy(from)
override def copy(source: DataInputView, target: DataOutputView): Unit = {
val isSuccess = source.readBoolean()
target.writeBoolean(isSuccess)
if (isSuccess) {
elemSerializer.copy(source, target)
} else {
throwableSerializer.copy(source, target)
}
}
override def serialize(either: Try[A], target: DataOutputView): Unit = either match {
case Success(a) =>
target.writeBoolean(true)
elemSerializer.serialize(a, target)
case Failure(t) =>
target.writeBoolean(false)
throwableSerializer.serialize(t, target)
}
override def deserialize(source: DataInputView): Try[A] = {
val isSuccess = source.readBoolean()
if (isSuccess) {
Success(elemSerializer.deserialize(source))
} else {
Failure(throwableSerializer.deserialize(source))
}
}
override def deserialize(reuse: Try[A], source: DataInputView): Try[A] = deserialize(source)
override def equals(obj: Any): Boolean = {
if (obj != null && obj.isInstanceOf[TrySerializer[_]]) {
val other = obj.asInstanceOf[TrySerializer[_]]
other.elemSerializer.equals(elemSerializer)
} else {
false
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala.typeutils
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.common.typeutils.TypeSerializer
import org.apache.flink.api.java.typeutils.runtime.KryoSerializer
import scala.util.Try
/**
* TypeInformation for [[scala.util.Try]].
*/
class TryTypeInfo[A, T <: Try[A]](elemTypeInfo: TypeInformation[A])
extends TypeInformation[T] {
override def isBasicType: Boolean = false
override def isTupleType: Boolean = false
override def isKeyType: Boolean = false
override def getTotalFields: Int = 1
override def getArity: Int = 1
override def getTypeClass = classOf[Try[_]].asInstanceOf[Class[T]]
def createSerializer(): TypeSerializer[T] = {
if (elemTypeInfo == null) {
// this happens when the type of a DataSet is None, i.e. DataSet[Failure]
new TrySerializer(new NothingSerializer).asInstanceOf[TypeSerializer[T]]
} else {
new TrySerializer(elemTypeInfo.createSerializer()).asInstanceOf[TypeSerializer[T]]
}
}
override def toString = s"Try[$elemTypeInfo]"
}
......@@ -28,6 +28,8 @@ import org.apache.flink.api.scala._
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import scala.util.{Failure, Success}
@RunWith(classOf[Parameterized])
class ScalaSpecialTypesITCase(mode: ExecutionMode) extends MultipleProgramsTestBase(mode) {
......@@ -100,6 +102,71 @@ class ScalaSpecialTypesITCase(mode: ExecutionMode) extends MultipleProgramsTestB
compareResultsByLinesInMemory("60", resultPath)
}
@Test
def testTry1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val nums = env.fromElements(1, 2, 1, 2)
val trys = nums.map(_ match {
case 1 => Success(10)
case 2 => Failure(new RuntimeException("20"))
})
val resultPath = tempFolder.newFile().toURI.toString
val result = trys.map{
_ match {
case Success(i) => i
case Failure(t) => t.getMessage.toInt
}}.reduce(_ + _).writeAsText(resultPath, WriteMode.OVERWRITE)
env.execute()
compareResultsByLinesInMemory("60", resultPath)
}
@Test
def testTry2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val nums = env.fromElements(1, 2, 1, 2)
val trys = nums.map(_ match {
case 1 => Success(10)
case 2 => Success(20)
})
val resultPath = tempFolder.newFile().toURI.toString
val result = trys.map(_ match {
case Success(i) => i
}).reduce(_ + _).writeAsText(resultPath, WriteMode.OVERWRITE)
env.execute()
compareResultsByLinesInMemory("60", resultPath)
}
@Test
def testTry3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val nums = env.fromElements(1, 2, 1, 2)
val trys = nums.map(_ match {
case 1 => Failure(new RuntimeException("10"))
case 2 => Failure(new IllegalAccessError("20"))
})
val resultPath = tempFolder.newFile().toURI.toString
val result = trys.map(_ match {
case Failure(t) => t.getMessage.toInt
}).reduce(_ + _).writeAsText(resultPath, WriteMode.OVERWRITE)
env.execute()
compareResultsByLinesInMemory("60", resultPath)
}
@Test
def testOption1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
......
......@@ -25,6 +25,8 @@ import org.junit.{Assert, Test}
import org.apache.flink.api.scala._
import scala.util.{Failure, Success}
class ScalaSpecialTypesSerializerTest {
@Test
......@@ -63,6 +65,26 @@ class ScalaSpecialTypesSerializerTest {
runTests(testData)
}
@Test
def testTry(): Unit = {
val testData = Array(Success("Hell"), Failure(new RuntimeException("test")))
runTests(testData)
}
@Test
def testSuccess(): Unit = {
val testData = Array(Success("Hell"), Success("Yeah"))
runTests(testData)
}
@Test
def testFailure(): Unit = {
val testData = Array(
Failure(new RuntimeException("test")),
Failure(new RuntimeException("one, two")))
runTests(testData)
}
private final def runTests[T : TypeInformation](instances: Array[T]) {
try {
......@@ -83,10 +105,10 @@ class ScalaSpecialTypesSerializerTest {
}
class ScalaSpecialTypesSerializerTestInstance[T](
serializer: TypeSerializer[T],
typeClass: Class[T],
length: Int,
testData: Array[T])
serializer: TypeSerializer[T],
typeClass: Class[T],
length: Int,
testData: Array[T])
extends SerializerTestInstance[T](serializer, typeClass, length, testData: _*) {
@Test
......@@ -123,6 +145,9 @@ class ScalaSpecialTypesSerializerTestInstance[T](
assertEquals(message, should, is)
}
case Failure(t) =>
is.asInstanceOf[Failure[_]].exception.equals(t)
case _ =>
super.deepEquals(message, should, is)
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册