提交 6b493fb0 编写于 作者: R Robert Metzger

Add Pojo support to Scala API

上级 aca6fbcd
......@@ -117,6 +117,7 @@ public abstract class CompositeType<T> extends TypeInformation<T> {
public int getPosition() {
return keyPosition;
}
public TypeInformation<?> getType() {
return type;
}
......
......@@ -62,14 +62,12 @@ public final class TupleTypeInfo<T extends Tuple> extends TupleTypeInfoBase<T> {
/**
* Comparator creation
*/
private TypeSerializer<?>[] fieldSerializers;
private TypeComparator<?>[] fieldComparators;
private int[] logicalKeyFields;
private int comparatorHelperIndex = 0;
@Override
protected void initializeNewComparator(int localKeyCount) {
fieldSerializers = new TypeSerializer[localKeyCount];
fieldComparators = new TypeComparator<?>[localKeyCount];
logicalKeyFields = new int[localKeyCount];
comparatorHelperIndex = 0;
......@@ -78,7 +76,6 @@ public final class TupleTypeInfo<T extends Tuple> extends TupleTypeInfoBase<T> {
@Override
protected void addCompareField(int fieldId, TypeComparator<?> comparator) {
fieldComparators[comparatorHelperIndex] = comparator;
fieldSerializers[comparatorHelperIndex] = types[fieldId].createSerializer();
logicalKeyFields[comparatorHelperIndex] = fieldId;
comparatorHelperIndex++;
}
......
......@@ -398,7 +398,7 @@ public class TypeExtractor {
return ObjectArrayTypeInfo.getInfoFor(t, componentInfo);
}
// objects with generics are treated as raw type
else if (t instanceof ParameterizedType) {
else if (t instanceof ParameterizedType) { //TODO
return privateGetForClass((Class<OUT>) ((ParameterizedType) t).getRawType(), typeHierarchy);
}
// no tuple, no TypeVariable, no generic type
......@@ -936,14 +936,13 @@ public class TypeExtractor {
return pojoType;
}
// return a generic type
return new GenericTypeInfo<X>(clazz);
}
/**
* Checks if the given field is a valid pojo field:
* - it is public
* - it is public
* OR
* - there are getter and setter methods for the field.
*
......@@ -968,8 +967,8 @@ public class TypeExtractor {
for(Method m : clazz.getMethods()) {
// check for getter
if( // The name should be "get<FieldName>".
m.getName().toLowerCase().contains("get"+fieldNameLow) &&
if( // The name should be "get<FieldName>" or "<fieldName>" (for scala).
(m.getName().toLowerCase().contains("get"+fieldNameLow) || m.getName().toLowerCase().contains(fieldNameLow)) &&
// no arguments for the getter
m.getParameterTypes().length == 0 &&
// return type is same as field type (or the generic variant of it)
......@@ -980,12 +979,12 @@ public class TypeExtractor {
}
hasGetter = true;
}
// check for setters
if( m.getName().toLowerCase().contains("set"+fieldNameLow) &&
m.getParameterTypes().length == 1 && // one parameter of the field's type
( m.getParameterTypes()[0].equals( fieldType ) || (fieldTypeGeneric != null && m.getGenericParameterTypes()[0].equals(fieldTypeGeneric) ) )&&
// return type is void.
m.getReturnType().equals(Void.TYPE)
// check for setters (<FieldName>_$eq for scala)
if((m.getName().toLowerCase().contains("set"+fieldNameLow) || m.getName().toLowerCase().contains(fieldNameLow+"_$eq")) &&
m.getParameterTypes().length == 1 && // one parameter of the field's type
( m.getParameterTypes()[0].equals( fieldType ) || (fieldTypeGeneric != null && m.getGenericParameterTypes()[0].equals(fieldTypeGeneric) ) )&&
// return type is void.
m.getReturnType().equals(Void.TYPE)
) {
if(hasSetter) {
throw new IllegalStateException("Detected more than one getters");
......@@ -993,7 +992,7 @@ public class TypeExtractor {
hasSetter = true;
}
}
if( hasGetter && hasSetter) {
if(hasGetter && hasSetter) {
return true;
} else {
if(!hasGetter) {
......
......@@ -340,7 +340,7 @@ public final class PojoComparator<T> extends CompositeTypeComparator<T> implemen
public int extractKeys(Object record, Object[] target, int index) {
int localIndex = index;
for (int i = 0; i < comparators.length; i++) {
if(comparators[i] instanceof PojoComparator || comparators[i] instanceof TupleComparator) {
if(comparators[i] instanceof CompositeTypeComparator) {
localIndex += comparators[i].extractKeys(accessField(keyFields[i], record), target, localIndex) -1;
} else {
// non-composite case (= atomic). We can assume this to have only one key.
......
......@@ -1247,7 +1247,7 @@ public class TypeExtractorTest {
public static class InType extends MyObject<String> {}
@SuppressWarnings({ "rawtypes", "unchecked" })
@Test
@Ignore
// @Ignore
public void testParamertizedCustomObject() {
RichMapFunction<?, ?> function = new RichMapFunction<InType, MyObject<String>>() {
private static final long serialVersionUID = 1L;
......
......@@ -31,7 +31,7 @@ import org.junit.Assert;
import org.junit.Ignore;
@Ignore // TODO
//@Ignore // TODO
public class PojoComparatorTest extends ComparatorTestBase<PojoContainingTuple> {
TypeInformation<PojoContainingTuple> type = TypeExtractor.getForClass(PojoContainingTuple.class);
......
......@@ -615,11 +615,11 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
* This only works on CaseClass DataSets.
*/
def groupBy(firstField: String, otherFields: String*): GroupedDataSet[T] = {
val fieldIndices = fieldNames2Indices(javaSet.getType, firstField +: otherFields.toArray)
// val fieldIndices = fieldNames2Indices(javaSet.getType, firstField +: otherFields.toArray)
new GroupedDataSet[T](
this,
new Keys.ExpressionKeys[T](fieldIndices, javaSet.getType,false))
new Keys.ExpressionKeys[T](firstField +: otherFields.toArray, javaSet.getType))
}
// public UnsortedGrouping<T> groupBy(String... fields) {
......
......@@ -32,68 +32,42 @@ abstract class CaseClassTypeInfo[T <: Product](
val fieldNames: Seq[String])
extends TupleTypeInfoBase[T](clazz, fieldTypes: _*) {
override def createComparator(logicalKeyFields: Array[Int],
orders: Array[Boolean], offset: Int): TypeComparator[T] = {
// sanity checks
if (logicalKeyFields == null || orders == null
|| logicalKeyFields.length != orders.length || logicalKeyFields.length > types.length) {
throw new IllegalArgumentException
}
// No special handling of leading Key field as in JavaTupleComparator for now
// --- general case ---
var maxKey: Int = -1
def getFieldIndices(fields: Array[String]): Array[Int] = {
fields map { x => fieldNames.indexOf(x) }
}
for (key <- logicalKeyFields) {
maxKey = Math.max(key, maxKey)
}
/*
* Comparator construction
*/
var fieldComparators: Array[TypeComparator[_]] = null
var logicalKeyFields : Array[Int] = null
var comparatorHelperIndex = 0
if (maxKey >= types.length) {
throw new IllegalArgumentException("The key position " + maxKey + " is out of range for " +
"Tuple" + types.length)
}
override protected def initializeNewComparator(localKeyCount: Int): Unit = {
fieldComparators = new Array(localKeyCount)
logicalKeyFields = new Array(localKeyCount)
comparatorHelperIndex = 0
}
// create the comparators for the individual fields
val fieldComparators: Array[TypeComparator[_]] = new Array(logicalKeyFields.length)
override protected def addCompareField(fieldId: Int, comparator: TypeComparator[_]): Unit = {
fieldComparators(comparatorHelperIndex) = comparator
logicalKeyFields(comparatorHelperIndex) = fieldId
comparatorHelperIndex += 1
}
for (i <- 0 until logicalKeyFields.length) {
val keyPos = logicalKeyFields(i)
if (types(keyPos).isKeyType && types(keyPos).isInstanceOf[AtomicType[_]]) {
fieldComparators(i) = types(keyPos).asInstanceOf[AtomicType[_]].createComparator(orders(i))
} else {
throw new IllegalArgumentException(
"The field at position " + i + " (" + types(keyPos) + ") is no atomic key type.")
}
override protected def getNewComparator: TypeComparator[T] = {
val finalLogicalKeyFields = logicalKeyFields.take(comparatorHelperIndex)
val finalComparators = fieldComparators.take(comparatorHelperIndex)
var maxKey: Int = 0
for (key <- finalLogicalKeyFields) {
maxKey = Math.max(maxKey, key)
}
// create the serializers for the prefix up to highest key position
val fieldSerializers: Array[TypeSerializer[_]] = new Array[TypeSerializer[_]](maxKey + 1)
for (i <- 0 to maxKey) {
fieldSerializers(i) = types(i).createSerializer
}
new CaseClassComparator[T](logicalKeyFields, fieldComparators, fieldSerializers)
}
def getFieldIndices(fields: Array[String]): Array[Int] = {
fields map { x => fieldNames.indexOf(x) }
}
override protected def initializeNewComparator(localKeyCount: Int): Unit = {
throw new UnsupportedOperationException("The Scala API is not using the composite " +
"type comparator creation")
}
override protected def getNewComparator: TypeComparator[T] = {
throw new UnsupportedOperationException("The Scala API is not using the composite " +
"type comparator creation")
}
override protected def addCompareField(fieldId: Int, comparator: TypeComparator[_]): Unit = {
throw new UnsupportedOperationException("The Scala API is not using the composite " +
"type comparator creation")
new CaseClassComparator[T](finalLogicalKeyFields, finalComparators, fieldSerializers)
}
override def toString = clazz.getSimpleName + "(" + fieldNames.zip(types).map {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.scala.operators
import org.apache.flink.api.scala.ExecutionEnvironment
import org.apache.flink.configuration.Configuration
import org.apache.flink.test.util.JavaProgramTestBase
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.apache.flink.api.scala._
import org.junit.runners.Parameterized.Parameters
import scala.collection.JavaConverters._
import scala.collection.mutable
// TODO case class Tuple2[T1, T2](_1: T1, _2: T2)
// TODO case class Foo(a: Int, b: String)
/**
 * Small mutable holder for a single `Long`, used as the nested member of [[Pojo]].
 *
 * @param myLong the wrapped value; declared as `var` so it can be reassigned
 */
class Nested(var myLong: Long) {

  /** No-argument constructor; delegates to the primary one with `myLong = 0`. */
  def this() = this(0L)
}
/**
 * Mutable test object with two directly declared fields and one nested object.
 *
 * @param myString mutable string field
 * @param myInt    mutable int field
 * @param myLong   only used during construction to initialise `nested.myLong`
 */
class Pojo(var myString: String, var myInt: Int, myLong: Long) {

  // Wraps the constructor argument so the value is reachable as `nested.myLong`.
  var nested = new Nested(myLong)

  /** No-argument constructor: empty string, zero int, zero nested long. */
  def this() = this("", 0, 0L)

  /** Renders all three values in the fixed layout the expected test output relies on. */
  override def toString(): String =
    s"myString=$myString myInt=$myInt nested.myLong=${nested.myLong}"
}
/**
 * Small grouping programs that are run by [[ExamplesITCase]].
 *
 * Every program builds a DataSet from a few elements, groups it, reduces each group,
 * writes the outcome to a text file and returns the text it expects to find there.
 */
object ExampleProgs {

  // Highest program id handled by runProgram; ExamplesITCase iterates 1 to NUM_PROGRAMS.
  var NUM_PROGRAMS: Int = 3

  /**
   * Runs the program identified by `progId`.
   *
   * @param progId       id of the program to run (1 to 3)
   * @param resultPath   target path passed to `writeAsText`
   * @param onCollection not read by any of the programs
   * @return the expected content of the result file, one record per line
   * @throws IllegalArgumentException if `progId` does not name a known program
   */
  def runProgram(progId: Int, resultPath: String, onCollection: Boolean): String = {
    progId match {
      case 1 =>
        /*
         * Nested tuples, grouped by an int position: field 0 is the inner tuple.
         */
        val env = ExecutionEnvironment.getExecutionEnvironment
        val ds = env.fromElements((("this", "is"), 1), (("this", "is"), 2), (("this", "hello"), 3))
        val grouped = ds.groupBy(0).reduce { (e1, e2) =>
          ((e1._1._1, e1._1._2), e1._2 + e2._2)
        }
        grouped.writeAsText(resultPath)
        env.execute()
        "((this,hello),3)\n((this,is),3)\n"

      case 2 =>
        /*
         * Nested tuples, grouped by a string key expression ("f0.f0") instead of an
         * int position. All three records end up in one group.
         */
        val env = ExecutionEnvironment.getExecutionEnvironment
        val ds = env.fromElements((("this", "is"), 1), (("this", "is"), 2), (("this", "hello"), 3))
        val grouped = ds.groupBy("f0.f0").reduce { (e1, e2) =>
          ((e1._1._1, e1._1._2), e1._2 + e2._2)
        }
        grouped.writeAsText(resultPath)
        env.execute()
        "((this,is),6)\n"

      case 3 =>
        /*
         * Nested pojos, grouped by a field of the nested object.
         */
        val env = ExecutionEnvironment.getExecutionEnvironment
        val ds = env.fromElements(
          new Pojo("one", 1, 1L),
          new Pojo("one", 1, 1L),
          new Pojo("two", 666, 2L))
        val grouped = ds.groupBy("nested.myLong").reduce { (p1, p2) =>
          // Accumulates into the first record and hands that same instance back.
          p1.myInt += p2.myInt
          p1
        }
        grouped.writeAsText(resultPath)
        env.execute()
        "myString=two myInt=666 nested.myLong=2\nmyString=one myInt=2 nested.myLong=1\n"

      case _ =>
        // Fail with a descriptive error instead of a bare scala.MatchError.
        throw new IllegalArgumentException("Invalid program id: " + progId)
    }
  }
}
/**
 * Parameterized integration test that runs one program of [[ExampleProgs]] per
 * instance and compares the written result with the text the program returned.
 *
 * @param config configuration whose "ProgramId" entry selects the program
 */
@RunWith(classOf[Parameterized])
class ExamplesITCase(config: Configuration) extends JavaProgramTestBase(config) {

  // Read once at construction; -1 is presumably the default for a missing
  // "ProgramId" entry (verify against Configuration.getInteger).
  private val selectedProgram: Int = config.getInteger("ProgramId", -1)

  // Both are filled in by the hooks below and stay null until then.
  private var resultPath: String = _
  private var expectedResult: String = _

  /** Picks the temporary location the program writes its result to. */
  override protected def preSubmit(): Unit = {
    resultPath = getTempDirPath("result")
  }

  /** Runs the selected program and remembers the output it expects. */
  protected def testProgram(): Unit = {
    expectedResult =
      ExampleProgs.runProgram(selectedProgram, resultPath, isCollectionExecution)
  }

  /** Compares the written result against the expected text. */
  override protected def postSubmit(): Unit = {
    compareResultsByLinesInMemory(expectedResult, resultPath)
  }
}
object ExamplesITCase {

  /**
   * Supplies the parameter sets for the Parameterized runner.
   *
   * @return one single-element array per program id from 1 to
   *         `ExampleProgs.NUM_PROGRAMS`, each holding a Configuration with that id
   *         stored under "ProgramId"
   */
  @Parameters
  def getConfigurations: java.util.Collection[Array[AnyRef]] = {
    val parameterSets = (1 to ExampleProgs.NUM_PROGRAMS).map { progId =>
      val config = new Configuration()
      config.setInteger("ProgramId", progId)
      // Element type is spelled out so the array is an Array[AnyRef].
      Array[AnyRef](config)
    }
    parameterSets.asJavaCollection
  }
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册