From 4672e95ef158bb5e52b1d493465d62429bbdb29e Mon Sep 17 00:00:00 2001 From: Aljoscha Krettek Date: Thu, 9 Apr 2015 14:05:08 +0200 Subject: [PATCH] [FLINK-1799][scala] Fix handling of generic arrays --- .../java/typeutils/ObjectArrayTypeInfo.java | 16 ++++ .../scala/codegen/TypeInformationGen.scala | 37 ++++++++- .../scala/types/TypeInformationGenTest.scala | 78 +++++++++++++++++-- 3 files changed, 119 insertions(+), 12 deletions(-) diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/ObjectArrayTypeInfo.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/ObjectArrayTypeInfo.java index 9be8408babb..76e1d2a5175 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/ObjectArrayTypeInfo.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/ObjectArrayTypeInfo.java @@ -18,6 +18,7 @@ package org.apache.flink.api.java.typeutils; +import java.lang.reflect.Array; import java.lang.reflect.GenericArrayType; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; @@ -143,6 +144,21 @@ public class ObjectArrayTypeInfo extends TypeInformation { throw new InvalidTypesException("The given type is not a valid object array."); } + /** + * Creates a new {@link org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo} from a + * {@link TypeInformation} for the component type. + * + *

+ * This must be used in cases where the complete type of the array is not available as a + * {@link java.lang.reflect.Type} or {@link java.lang.Class}. + */ + public static ObjectArrayTypeInfo getInfoFor(TypeInformation componentInfo) { + return new ObjectArrayTypeInfo( + Array.newInstance(componentInfo.getTypeClass(), 0).getClass(), + componentInfo.getTypeClass(), + componentInfo); + } + @SuppressWarnings("unchecked") public static ObjectArrayTypeInfo getInfoFor(Type type) { // class type e.g. for POJOs diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala index 118b8c466f6..0a8874b9597 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/codegen/TypeInformationGen.scala @@ -222,10 +222,39 @@ private[flink] trait TypeInformationGen[C <: Context] { } case _ => reify { - ObjectArrayTypeInfo.getInfoFor( - arrayClazz.splice, - elementTypeInfo.splice.asInstanceOf[TypeInformation[_]]) - .asInstanceOf[TypeInformation[T]] + val elementType = elementTypeInfo.splice.asInstanceOf[TypeInformation[_]] + val result = elementType match { + case BasicTypeInfo.BOOLEAN_TYPE_INFO => + PrimitiveArrayTypeInfo.BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO + + case BasicTypeInfo.BYTE_TYPE_INFO => + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO + + case BasicTypeInfo.CHAR_TYPE_INFO => + PrimitiveArrayTypeInfo.CHAR_PRIMITIVE_ARRAY_TYPE_INFO + + case BasicTypeInfo.DOUBLE_TYPE_INFO => + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO + + case BasicTypeInfo.FLOAT_TYPE_INFO => + PrimitiveArrayTypeInfo.FLOAT_PRIMITIVE_ARRAY_TYPE_INFO + + case BasicTypeInfo.INT_TYPE_INFO => + PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO + + case BasicTypeInfo.LONG_TYPE_INFO => + PrimitiveArrayTypeInfo.LONG_PRIMITIVE_ARRAY_TYPE_INFO + + case BasicTypeInfo.SHORT_TYPE_INFO => + PrimitiveArrayTypeInfo.SHORT_PRIMITIVE_ARRAY_TYPE_INFO + + case BasicTypeInfo.STRING_TYPE_INFO => + BasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO + + case _ => + ObjectArrayTypeInfo.getInfoFor(elementType) + } + result.asInstanceOf[TypeInformation[T]] } } } diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/types/TypeInformationGenTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/types/TypeInformationGenTest.scala index 08ba49d0c57..63d05bebd0c 100644 --- a/flink-tests/src/test/scala/org/apache/flink/api/scala/types/TypeInformationGenTest.scala +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/types/TypeInformationGenTest.scala @@ -17,18 +17,16 @@ */ package org.apache.flink.api.scala.types -import java.io.DataInput -import java.io.DataOutput +import java.io.{DataInput, DataOutput} + +import org.apache.hadoop.io.Writable +import org.junit.{Assert, Test} + import org.apache.flink.api.common.typeinfo._ -import org.apache.flink.api.common.typeutils._ import org.apache.flink.api.java.typeutils._ +import org.apache.flink.api.scala._ import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo import org.apache.flink.types.{IntValue, StringValue} -import org.apache.hadoop.io.Writable -import org.junit.Assert -import org.junit.Test - -import org.apache.flink.api.scala._ class MyWritable extends Writable { def write(out: DataOutput) { @@ -83,6 +81,70 @@ class TypeInformationGenTest { } + @Test + def testGenericArrays(): Unit = { + + class MyObject(var a: Int, var b: String) { + def this() = this(0, "") + } + + val boolArray = Array(true, false) + val byteArray = Array(1.toByte, 2.toByte, 3.toByte) + val charArray= Array(1.toChar, 2.toChar, 3.toChar) + val shortArray = Array(1.toShort, 2.toShort, 3.toShort) + val intArray = Array(1, 2, 3) + val longArray = Array(1L, 2L, 3L) + val floatArray = Array(1.0f, 2.0f, 3.0f) + val doubleArray = Array(1.0, 2.0, 3.0) + val stringArray = Array("hey", "there") + val objectArray = Array(new MyObject(1, "hey"), new MyObject(2, "there")) + + def getType[T: TypeInformation](arr: Array[T]): TypeInformation[Array[T]] = { + createTypeInformation[Array[T]] + } + + Assert.assertEquals( + PrimitiveArrayTypeInfo.BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO, + getType(boolArray)) + + Assert.assertEquals( + PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO, + getType(byteArray)) + + Assert.assertEquals( + PrimitiveArrayTypeInfo.CHAR_PRIMITIVE_ARRAY_TYPE_INFO, + getType(charArray)) + + Assert.assertEquals( + PrimitiveArrayTypeInfo.SHORT_PRIMITIVE_ARRAY_TYPE_INFO, + getType(shortArray)) + + Assert.assertEquals( + PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO, + getType(intArray)) + + Assert.assertEquals( + PrimitiveArrayTypeInfo.LONG_PRIMITIVE_ARRAY_TYPE_INFO, + getType(longArray)) + + Assert.assertEquals( + PrimitiveArrayTypeInfo.FLOAT_PRIMITIVE_ARRAY_TYPE_INFO, + getType(floatArray)) + + Assert.assertEquals( + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO, + getType(doubleArray)) + + Assert.assertEquals( + BasicArrayTypeInfo.STRING_ARRAY_TYPE_INFO, + getType(stringArray)) + + Assert.assertTrue(getType(objectArray).isInstanceOf[ObjectArrayTypeInfo[_, _]]) + Assert.assertTrue( + getType(objectArray).asInstanceOf[ObjectArrayTypeInfo[_, _]] + .getComponentInfo.isInstanceOf[PojoTypeInfo[_]]) + } + @Test def testWritableType(): Unit = { val ti = createTypeInformation[MyWritable] -- GitLab