diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoComparator.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoComparator.java index 51d8090ad2ef9f272caf1c5bda2e4caf25e7d86f..2cccfcf65b8f293311fb2d32138d57d973d28669 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoComparator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoComparator.java @@ -340,14 +340,7 @@ public final class PojoComparator extends CompositeTypeComparator implemen public int extractKeys(Object record, Object[] target, int index) { int localIndex = index; for (int i = 0; i < comparators.length; i++) { - if(comparators[i] instanceof CompositeTypeComparator) { - localIndex += comparators[i].extractKeys(accessField(keyFields[i], record), target, localIndex) -1; - } else { - // non-composite case (= atomic). We can assume this to have only one key. - // comparators[i].extractKeys(accessField(keyFields[i], record), target, i); - target[localIndex] = accessField(keyFields[i], record); - } - localIndex++; + localIndex += comparators[i].extractKeys(accessField(keyFields[i], record), target, localIndex); } return localIndex - index; } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleComparator.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleComparator.java index 61a1567c8a66f6320776812fb7eccaa9f9082fc8..89b77945d962efae7de0fec3ac8becb8c6dcff42 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleComparator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleComparator.java @@ -147,14 +147,7 @@ public final class TupleComparator extends TupleComparatorBase< public int extractKeys(Object record, Object[] target, int index) { int localIndex = index; for(int i = 0; i < comparators.length; i++) { - // handle nested case - if(comparators[i] instanceof TupleComparator || comparators[i] instanceof PojoComparator) { - localIndex += comparators[i].extractKeys(((Tuple) record).getField(keyPositions[i]), target, localIndex) -1; - } else { - // flat - target[localIndex] = ((Tuple) record).getField(keyPositions[i]); - } - localIndex++; + localIndex += comparators[i].extractKeys(((Tuple) record).getField(keyPositions[i]), target, localIndex); } return localIndex - index; } diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassComparator.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassComparator.scala index 1353b44420b87890043ee14638c06ff1167be76a..bde009c4bb1a1a42467a1b735ff427184b6e726a 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassComparator.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassComparator.scala @@ -17,7 +17,8 @@ */ package org.apache.flink.api.scala.typeutils -import org.apache.flink.api.common.typeutils.{TypeComparator, TypeSerializer} +import org.apache.flink.api.common.typeutils.{CompositeTypeComparator, TypeComparator, +TypeSerializer} import org.apache.flink.api.java.typeutils.runtime.TupleComparatorBase import org.apache.flink.core.memory.MemorySegment import org.apache.flink.types.{KeyFieldOutOfBoundsException, NullKeyFieldException} @@ -140,9 +141,16 @@ class CaseClassComparator[T <: Product]( } def extractKeys(value: AnyRef, target: Array[AnyRef], index: Int) = { - for (i <- 0 until keyPositions.length ) { - target(index + i) = value.asInstanceOf[T].productElement(keyPositions(i)).asInstanceOf[AnyRef] + val in = value.asInstanceOf[T] + + var localIndex: Int = index + for (i <- 0 until comparators.length) { + localIndex += comparators(i).extractKeys( + in.productElement(keyPositions(i)), + target, + localIndex) } - keyPositions.length + + localIndex - index } }