提交 4672e95e 编写于 作者: A Aljoscha Krettek

[FLINK-1799][scala] Fix handling of generic arrays

上级 56cb7937
......@@ -18,6 +18,7 @@
package org.apache.flink.api.java.typeutils;
import java.lang.reflect.Array;
import java.lang.reflect.GenericArrayType;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
......@@ -143,6 +144,21 @@ public class ObjectArrayTypeInfo<T, C> extends TypeInformation<T> {
throw new InvalidTypesException("The given type is not a valid object array.");
}
/**
* Creates a new {@link org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo} from a
* {@link TypeInformation} for the component type.
*
* <p>
* This must be used in cases where the complete type of the array is not available as a
* {@link java.lang.reflect.Type} or {@link java.lang.Class}.
*/
public static <T, C> ObjectArrayTypeInfo<T, C> getInfoFor(TypeInformation<C> componentInfo) {
return new ObjectArrayTypeInfo<T, C>(
Array.newInstance(componentInfo.getTypeClass(), 0).getClass(),
componentInfo.getTypeClass(),
componentInfo);
}
@SuppressWarnings("unchecked")
public static <T, C> ObjectArrayTypeInfo<T, C> getInfoFor(Type type) {
// class type e.g. for POJOs
......
......@@ -222,10 +222,39 @@ private[flink] trait TypeInformationGen[C <: Context] {
}
case _ =>
reify {
ObjectArrayTypeInfo.getInfoFor(
arrayClazz.splice,
elementTypeInfo.splice.asInstanceOf[TypeInformation[_]])
.asInstanceOf[TypeInformation[T]]
val elementType = elementTypeInfo.splice.asInstanceOf[TypeInformation[_]]
val result = elementType match {
case BasicTypeInfo.BOOLEAN_TYPE_INFO =>
PrimitiveArrayTypeInfo.BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO
case BasicTypeInfo.BYTE_TYPE_INFO =>
PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO
case BasicTypeInfo.CHAR_TYPE_INFO =>
PrimitiveArrayTypeInfo.CHAR_PRIMITIVE_ARRAY_TYPE_INFO
case BasicTypeInfo.DOUBLE_TYPE_INFO =>
PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO
case BasicTypeInfo.FLOAT_TYPE_INFO =>
PrimitiveArrayTypeInfo.FLOAT_PRIMITIVE_ARRAY_TYPE_INFO
case BasicTypeInfo.INT_TYPE_INFO =>
PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO
case BasicTypeInfo.LONG_TYPE_INFO =>
PrimitiveArrayTypeInfo.LONG_PRIMITIVE_ARRAY_TYPE_INFO
case BasicTypeInfo.SHORT_TYPE_INFO =>
PrimitiveArrayTypeInfo.SHORT_PRIMITIVE_ARRAY_TYPE_INFO
case BasicTypeInfo.STRING_TYPE_INFO =>
BasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO
case _ =>
ObjectArrayTypeInfo.getInfoFor(elementType)
}
result.asInstanceOf[TypeInformation[T]]
}
}
}
......
......@@ -17,18 +17,16 @@
*/
package org.apache.flink.api.scala.types
import java.io.DataInput
import java.io.DataOutput
import java.io.{DataInput, DataOutput}
import org.apache.hadoop.io.Writable
import org.junit.{Assert, Test}
import org.apache.flink.api.common.typeinfo._
import org.apache.flink.api.common.typeutils._
import org.apache.flink.api.java.typeutils._
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
import org.apache.flink.types.{IntValue, StringValue}
import org.apache.hadoop.io.Writable
import org.junit.Assert
import org.junit.Test
import org.apache.flink.api.scala._
class MyWritable extends Writable {
def write(out: DataOutput) {
......@@ -83,6 +81,70 @@ class TypeInformationGenTest {
}
@Test
def testGenericArrays(): Unit = {
class MyObject(var a: Int, var b: String) {
def this() = this(0, "")
}
val boolArray = Array(true, false)
val byteArray = Array(1.toByte, 2.toByte, 3.toByte)
val charArray= Array(1.toChar, 2.toChar, 3.toChar)
val shortArray = Array(1.toShort, 2.toShort, 3.toShort)
val intArray = Array(1, 2, 3)
val longArray = Array(1L, 2L, 3L)
val floatArray = Array(1.0f, 2.0f, 3.0f)
val doubleArray = Array(1.0, 2.0, 3.0)
val stringArray = Array("hey", "there")
val objectArray = Array(new MyObject(1, "hey"), new MyObject(2, "there"))
def getType[T: TypeInformation](arr: Array[T]): TypeInformation[Array[T]] = {
createTypeInformation[Array[T]]
}
Assert.assertEquals(
PrimitiveArrayTypeInfo.BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO,
getType(boolArray))
Assert.assertEquals(
PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO,
getType(byteArray))
Assert.assertEquals(
PrimitiveArrayTypeInfo.CHAR_PRIMITIVE_ARRAY_TYPE_INFO,
getType(charArray))
Assert.assertEquals(
PrimitiveArrayTypeInfo.SHORT_PRIMITIVE_ARRAY_TYPE_INFO,
getType(shortArray))
Assert.assertEquals(
PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO,
getType(intArray))
Assert.assertEquals(
PrimitiveArrayTypeInfo.LONG_PRIMITIVE_ARRAY_TYPE_INFO,
getType(longArray))
Assert.assertEquals(
PrimitiveArrayTypeInfo.FLOAT_PRIMITIVE_ARRAY_TYPE_INFO,
getType(floatArray))
Assert.assertEquals(
PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
getType(doubleArray))
Assert.assertEquals(
BasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO,
getType(stringArray))
Assert.assertTrue(getType(objectArray).isInstanceOf[ObjectArrayTypeInfo[_, _]])
Assert.assertTrue(
getType(objectArray).asInstanceOf[ObjectArrayTypeInfo[_, _]]
.getComponentInfo.isInstanceOf[PojoTypeInfo[_]])
}
@Test
def testWritableType(): Unit = {
val ti = createTypeInformation[MyWritable]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册