From d2e9aeb55d5c7bfaa65f92163a142f53200711ef Mon Sep 17 00:00:00 2001 From: Timo Walther Date: Tue, 15 Dec 2020 16:43:23 +0100 Subject: [PATCH] [FLINK-19981][core][table] Add name-based field mode for Row This adds a name-based field mode to the Row class. A row can operate in 3 different modes: name-based, position-based, or a hybrid of both when leaving the Flink runtime. It simplifies the handling of large rows (possibly with hundreds of fields) and will make it easier to switch between DataStream API and Table API. See the documentation of the Row class for more information. This closes #14420. --- .../flink/api/java/typeutils/RowTypeInfo.java | 9 +- .../java/typeutils/runtime/RowSerializer.java | 228 ++++++++-- .../main/java/org/apache/flink/types/Row.java | 422 +++++++++++++++--- .../java/org/apache/flink/types/RowUtils.java | 178 +++++++- .../typeutils/runtime/RowSerializerTest.java | 56 ++- .../flink/testutils/DeeplyEqualsChecker.java | 22 - .../java/org/apache/flink/types/RowTest.java | 334 +++++++++++++- .../api/utils/PythonTypeUtilsTest.java | 2 +- .../sources/RowArrowSourceFunctionTest.java | 2 +- .../planner/functions/RowFunctionITCase.java | 2 +- .../table/TemporalTableFunctionJoinTest.xml | 8 +- .../data/conversion/RowRowConverter.java | 53 ++- .../data/DataStructureConvertersTest.java | 10 + 13 files changed, 1132 insertions(+), 194 deletions(-) diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/RowTypeInfo.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/RowTypeInfo.java index 666726d0763..536bc7df214 100644 --- a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/RowTypeInfo.java +++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/RowTypeInfo.java @@ -32,6 +32,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -253,7 +254,11 @@ public class RowTypeInfo extends TupleTypeInfoBase { for (int i = 0; i < len; i++) { fieldSerializers[i] = types[i].createSerializer(config); } - return new RowSerializer(fieldSerializers); + final LinkedHashMap positionByName = new LinkedHashMap<>(); + for (int i = 0; i < fieldNames.length; i++) { + positionByName.put(fieldNames[i], i); + } + return new RowSerializer(fieldSerializers, positionByName); } @Override @@ -308,7 +313,7 @@ public class RowTypeInfo extends TupleTypeInfoBase { for (int i = 0; i < len; i++) { fieldSerializers[i] = types[i].createSerializer(config); } - return new RowSerializer(fieldSerializers, true); + return new RowSerializer(fieldSerializers, null, true); } /** Tests whether an other object describes the same, schema-equivalent row information. 
*/ diff --git a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/RowSerializer.java b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/RowSerializer.java index d1e3bcd486f..d9bfc677441 100644 --- a/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/RowSerializer.java +++ b/flink-core/src/main/java/org/apache/flink/api/java/typeutils/runtime/RowSerializer.java @@ -25,15 +25,21 @@ import org.apache.flink.api.common.typeutils.CompositeTypeSerializerUtil; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility; import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.types.Row; import org.apache.flink.types.RowKind; +import org.apache.flink.types.RowUtils; + +import javax.annotation.Nullable; import java.io.IOException; import java.io.ObjectInputStream; import java.util.Arrays; +import java.util.LinkedHashMap; import java.util.Objects; +import java.util.Set; import static org.apache.flink.api.java.typeutils.runtime.MaskUtils.readIntoAndCopyMask; import static org.apache.flink.api.java.typeutils.runtime.MaskUtils.readIntoMask; @@ -57,6 +63,11 @@ import static org.apache.flink.util.Preconditions.checkNotNull; * bitmask with row kind: |RK RK F1 F2 ... FN| * bitmask in legacy mode: |F1 F2 ... FN| * + * + *

Field names are an optional part of this serializer. They allow to use rows in named-based + * field mode. However, the support for name-based rows is limited. Usually, name-based mode should + * not be used in state but only for in-flight data. For now, names are not part of serializer + * snapshot or equals/hashCode (similar to {@link RowTypeInfo}). */ @Internal public final class RowSerializer extends TypeSerializer { @@ -74,19 +85,34 @@ public final class RowSerializer extends TypeSerializer { private final int arity; + private final @Nullable LinkedHashMap positionByName; + private transient boolean[] mask; + private transient Row reuseRowPositionBased; + public RowSerializer(TypeSerializer[] fieldSerializers) { - this(fieldSerializers, false); + this(fieldSerializers, null, false); + } + + public RowSerializer( + TypeSerializer[] fieldSerializers, + @Nullable LinkedHashMap positionByName) { + this(fieldSerializers, positionByName, false); } @SuppressWarnings("unchecked") - public RowSerializer(TypeSerializer[] fieldSerializers, boolean legacyModeEnabled) { + public RowSerializer( + TypeSerializer[] fieldSerializers, + @Nullable LinkedHashMap positionByName, + boolean legacyModeEnabled) { this.legacyModeEnabled = legacyModeEnabled; this.legacyOffset = legacyModeEnabled ? 0 : ROW_KIND_OFFSET; this.fieldSerializers = (TypeSerializer[]) checkNotNull(fieldSerializers); this.arity = fieldSerializers.length; + this.positionByName = positionByName; this.mask = new boolean[legacyOffset + fieldSerializers.length]; + this.reuseRowPositionBased = new Row(fieldSerializers.length); } @Override @@ -100,71 +126,112 @@ public final class RowSerializer extends TypeSerializer { for (int i = 0; i < fieldSerializers.length; i++) { duplicateFieldSerializers[i] = fieldSerializers[i].duplicate(); } - return new RowSerializer(duplicateFieldSerializers, legacyModeEnabled); + return new RowSerializer(duplicateFieldSerializers, positionByName, legacyModeEnabled); } @Override public Row createInstance() { - return new Row(fieldSerializers.length); + return RowUtils.createRowWithNamedPositions( + RowKind.INSERT, new Object[fieldSerializers.length], positionByName); } @Override public Row copy(Row from) { - int len = fieldSerializers.length; + final Set fieldNames = from.getFieldNames(false); + if (fieldNames == null) { + return copyPositionBased(from); + } else { + return copyNameBased(from, fieldNames); + } + } - if (from.getArity() != len) { + private Row copyPositionBased(Row from) { + final int length = fieldSerializers.length; + if (from.getArity() != length) { throw new RuntimeException( "Row arity of from (" + from.getArity() - + ") does not match this serializers field length (" - + len + + ") does not match " + + "this serializer's field length (" + + length + ")."); } + final Object[] fieldByPosition = new Object[length]; + for (int i = 0; i < length; i++) { + final Object fromField = from.getField(i); + if (fromField != null) { + final Object copy = fieldSerializers[i].copy(fromField); + fieldByPosition[i] = copy; + } + } + return RowUtils.createRowWithNamedPositions( + from.getKind(), fieldByPosition, positionByName); + } - Row result = new Row(from.getKind(), len); - for (int i = 0; i < len; i++) { - Object fromField = from.getField(i); + private Row copyNameBased(Row from, Set fieldNames) { + if (positionByName == null) { + throw new RuntimeException("Serializer does not support named field positions."); + } + final Row newRow = Row.withNames(from.getKind()); + for (String fieldName : fieldNames) { 
+ final int targetPos = getPositionByName(fieldName); + final Object fromField = from.getField(fieldName); if (fromField != null) { - Object copy = fieldSerializers[i].copy(fromField); - result.setField(i, copy); + final Object copy = fieldSerializers[targetPos].copy(fromField); + newRow.setField(fieldName, copy); } else { - result.setField(i, null); + newRow.setField(fieldName, null); } } - return result; + return newRow; } @Override public Row copy(Row from, Row reuse) { - int len = fieldSerializers.length; - // cannot reuse, do a non-reuse copy if (reuse == null) { return copy(from); } - if (from.getArity() != len || reuse.getArity() != len) { + final Set fieldNames = from.getFieldNames(false); + if (fieldNames == null) { + // reuse uses name-based field mode, do a non-reuse copy + if (reuse.getFieldNames(false) != null) { + return copy(from); + } + return copyPositionBased(from, reuse); + } else { + // reuse uses position-based field mode, do a non-reuse copy + if (reuse.getFieldNames(false) == null) { + return copy(from); + } + return copyNameBased(from, fieldNames, reuse); + } + } + + private Row copyPositionBased(Row from, Row reuse) { + final int length = fieldSerializers.length; + if (from.getArity() != length || reuse.getArity() != length) { throw new RuntimeException( "Row arity of reuse (" + reuse.getArity() + ") or from (" + from.getArity() - + ") is incompatible with this serializers field length (" - + len + + ") is " + + "incompatible with this serializer's field length (" + + length + ")."); } - reuse.setKind(from.getKind()); - - for (int i = 0; i < len; i++) { - Object fromField = from.getField(i); + for (int i = 0; i < length; i++) { + final Object fromField = from.getField(i); if (fromField != null) { - Object reuseField = reuse.getField(i); + final Object reuseField = reuse.getField(i); if (reuseField != null) { - Object copy = fieldSerializers[i].copy(fromField, reuseField); + final Object copy = fieldSerializers[i].copy(fromField, reuseField); reuse.setField(i, copy); } else { - Object copy = fieldSerializers[i].copy(fromField); + final Object copy = fieldSerializers[i].copy(fromField); reuse.setField(i, copy); } } else { @@ -174,6 +241,29 @@ public final class RowSerializer extends TypeSerializer { return reuse; } + private Row copyNameBased(Row from, Set fieldNames, Row reuse) { + if (positionByName == null) { + throw new RuntimeException("Serializer does not support named field positions."); + } + reuse.clear(); + reuse.setKind(from.getKind()); + for (String fieldName : fieldNames) { + final int targetPos = getPositionByName(fieldName); + final Object fromField = from.getField(fieldName); + if (fromField != null) { + final Object reuseField = reuse.getField(fieldName); + if (reuseField != null) { + final Object copy = fieldSerializers[targetPos].copy(fromField, reuseField); + reuse.setField(fieldName, copy); + } else { + final Object copy = fieldSerializers[targetPos].copy(fromField); + reuse.setField(fieldName, copy); + } + } + } + return reuse; + } + @Override public int getLength() { return -1; @@ -185,23 +275,32 @@ public final class RowSerializer extends TypeSerializer { @Override public void serialize(Row record, DataOutputView target) throws IOException { - final int len = fieldSerializers.length; + final Set fieldNames = record.getFieldNames(false); + if (fieldNames == null) { + serializePositionBased(record, target); + } else { + serializeNameBased(record, fieldNames, target); + } + } - if (record.getArity() != len) { + private void 
serializePositionBased(Row record, DataOutputView target) throws IOException { + final int length = fieldSerializers.length; + if (record.getArity() != length) { throw new RuntimeException( "Row arity of record (" + record.getArity() - + ") does not match this serializers field length (" - + len + + ") does not match this " + + "serializer's field length (" + + length + ")."); } // write bitmask - fillMask(len, record, mask, legacyModeEnabled, legacyOffset); + fillMask(length, record, mask, legacyModeEnabled, legacyOffset); writeMask(mask, target); // serialize non-null fields - for (int fieldPos = 0; fieldPos < len; fieldPos++) { + for (int fieldPos = 0; fieldPos < length; fieldPos++) { final Object o = record.getField(fieldPos); if (o != null) { fieldSerializers[fieldPos].serialize(o, target); @@ -209,39 +308,62 @@ public final class RowSerializer extends TypeSerializer { } } + private void serializeNameBased(Row record, Set fieldNames, DataOutputView target) + throws IOException { + if (positionByName == null) { + throw new RuntimeException("Serializer does not support named field positions."); + } + reuseRowPositionBased.clear(); + reuseRowPositionBased.setKind(record.getKind()); + for (String fieldName : fieldNames) { + final int targetPos = getPositionByName(fieldName); + final Object value = record.getField(fieldName); + reuseRowPositionBased.setField(targetPos, value); + } + serializePositionBased(reuseRowPositionBased, target); + } + @Override public Row deserialize(DataInputView source) throws IOException { - final int len = fieldSerializers.length; + final int length = fieldSerializers.length; // read bitmask readIntoMask(source, mask); - final Row result; + + // read row kind + final RowKind kind; if (legacyModeEnabled) { - result = new Row(len); + kind = RowKind.INSERT; } else { - result = new Row(readKindFromMask(mask), len); + kind = readKindFromMask(mask); } // deserialize fields - for (int fieldPos = 0; fieldPos < len; fieldPos++) { + final Object[] fieldByPosition = new Object[length]; + for (int fieldPos = 0; fieldPos < length; fieldPos++) { if (!mask[legacyOffset + fieldPos]) { - result.setField(fieldPos, fieldSerializers[fieldPos].deserialize(source)); + fieldByPosition[fieldPos] = fieldSerializers[fieldPos].deserialize(source); } } - return result; + return RowUtils.createRowWithNamedPositions(kind, fieldByPosition, positionByName); } @Override public Row deserialize(Row reuse, DataInputView source) throws IOException { - final int len = fieldSerializers.length; + // reuse uses name-based field mode, do a non-reuse deserialize + if (reuse == null || reuse.getFieldNames(false) != null) { + return deserialize(source); + } + final int length = fieldSerializers.length; - if (reuse.getArity() != len) { + if (reuse.getArity() != length) { throw new RuntimeException( "Row arity of reuse (" + reuse.getArity() - + ") does not match this serializers field length (" - + len + + ") does not match " + + "this serializer's field length (" + + length + ")."); } @@ -252,7 +374,7 @@ public final class RowSerializer extends TypeSerializer { } // deserialize fields - for (int fieldPos = 0; fieldPos < len; fieldPos++) { + for (int fieldPos = 0; fieldPos < length; fieldPos++) { if (mask[legacyOffset + fieldPos]) { reuse.setField(fieldPos, null); } else { @@ -306,9 +428,23 @@ public final class RowSerializer extends TypeSerializer { // -------------------------------------------------------------------------------------------- + private int getPositionByName(String fieldName) { + assert 
positionByName != null; + final Integer targetPos = positionByName.get(fieldName); + if (targetPos == null) { + throw new RuntimeException( + String.format( + "Unknown field name '%s' for mapping to a row position. " + + "Available names are: %s", + fieldName, positionByName.keySet())); + } + return targetPos; + } + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { in.defaultReadObject(); this.mask = new boolean[legacyOffset + fieldSerializers.length]; + this.reuseRowPositionBased = new Row(fieldSerializers.length); } // -------------------------------------------------------------------------------------------- @@ -426,7 +562,7 @@ public final class RowSerializer extends TypeSerializer { protected RowSerializer createOuterSerializerWithNestedSerializers( TypeSerializer[] nestedSerializers) { return new RowSerializer( - nestedSerializers, readVersion <= LAST_VERSION_WITHOUT_ROW_KIND); + nestedSerializers, null, readVersion <= LAST_VERSION_WITHOUT_ROW_KIND); } } } diff --git a/flink-core/src/main/java/org/apache/flink/types/Row.java b/flink-core/src/main/java/org/apache/flink/types/Row.java index d2d37dd2d64..b86c09a6898 100644 --- a/flink-core/src/main/java/org/apache/flink/types/Row.java +++ b/flink-core/src/main/java/org/apache/flink/types/Row.java @@ -18,12 +18,17 @@ package org.apache.flink.types; import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.java.typeutils.runtime.RowSerializer; import org.apache.flink.util.Preconditions; -import org.apache.flink.util.StringUtils; import javax.annotation.Nullable; import java.io.Serializable; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; import static org.apache.flink.types.RowUtils.deepEqualsRow; import static org.apache.flink.types.RowUtils.deepHashCodeRow; @@ -38,11 +43,38 @@ import static org.apache.flink.types.RowUtils.deepHashCodeRow; * Therefore, a row does not only consist of a schema part (containing the fields) but also attaches * a {@link RowKind} for encoding a change in a changelog. Thus, a row can be considered as an entry * in a changelog. For example, in regular batch scenarios, a changelog would consist of a bounded - * stream of {@link RowKind#INSERT} rows. + * stream of {@link RowKind#INSERT} rows. The row kind is kept separate from the fields and can be + * accessed by using {@link #getKind()} and {@link #setKind(RowKind)}. * - *

The fields of a row can be accessed by position (zero-based) using {@link #getField(int)} and - * {@link #setField(int, Object)}. The row kind is kept separate from the fields and can be accessed - * by using {@link #getKind()} and {@link #setKind(RowKind)}. + *

Fields of a row can be accessed either position-based or name-based. An implementer can decide + * in which field mode a row should operate during creation. Rows that were produced by the + * framework support a hybrid of both field modes (i.e. named positions): + * + *

Position-based field mode

+ * + *

{@link Row#withPositions(int)} creates a fixed-length row. The fields can be accessed by + * position (zero-based) using {@link #getField(int)} and {@link #setField(int, Object)}. Every + * field is initialized with {@code null} by default. + * + *
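// A minimal sketch of the position-based mode described above (illustrative only, not part of
// this patch's diff); the concrete values are arbitrary examples.
final Row positioned = Row.withPositions(RowKind.INSERT, 3); // every field starts as null
positioned.setField(0, "hello");
positioned.setField(1, true);
positioned.setField(2, 42L);
final String greeting = positioned.getFieldAs(0); // typed access, no manual cast needed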

Name-based field mode

+ * + *

{@link Row#withNames()} creates a variable-length row. The fields can be accessed by name + * using {@link #getField(String)} and {@link #setField(String, Object)}. Every field is initialized + * during the first call to {@link #setField(String, Object)} for the given name. However, the + * framework will initialize missing fields with {@code null} and reorder all fields as soon as more + * type information is available during serialization or input conversion. Thus, even name-based rows + * eventually become fixed-length composite types with a deterministic field order. Name-based rows + * perform worse than position-based rows but simplify row creation and improve code readability. + * + *
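// A minimal sketch of the name-based mode described above (illustrative only, not part of this
// patch's diff); the field names and values are arbitrary examples. Row.project(Row, String[])
// is the name-based projection added further down in this patch.
final Row named = Row.withNames(); // variable-length, RowKind.INSERT by default
named.setField("id", 42L);
named.setField("active", true);
final Long id = named.getFieldAs("id");            // typed access by name
named.getFieldNames(false);                        // => {"id", "active"}
final Row projected = Row.project(named, new String[] {"id"}); // keeps only "id"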

Hybrid / named-position field mode

+ * + *

Rows that were produced by the framework (after deserialization or output conversion) are + * fixed-length rows with a deterministic field order that can map static field names to field + * positions. Thus, fields can be accessed both via {@link #getField(int)} and {@link + * #getField(String)}. Both {@link #setField(int, Object)} and {@link #setField(String, Object)} are + * supported for existing fields. However, adding new field names via {@link #setField(String, + * Object)} is not allowed. A hybrid row's {@link #equals(Object)} supports comparing to all kinds + * of rows. A hybrid row's {@link #hashCode()} is only valid for position-based rows. * *
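// A minimal sketch of the hybrid mode described above (illustrative only, not part of this
// patch's diff). Hybrid rows normally come out of the framework (after deserialization or
// output conversion); RowUtils.createRowWithNamedPositions is the @Internal factory added by
// this patch and is used here only to obtain such a row for demonstration.
final LinkedHashMap<String, Integer> positionByName = new LinkedHashMap<>();
positionByName.put("id", 0);
positionByName.put("active", 1);
final Row hybrid = RowUtils.createRowWithNamedPositions(
        RowKind.INSERT, new Object[2], positionByName);
hybrid.setField(0, 42L);                 // position-based access
hybrid.setField("active", true);         // name-based access to an existing field
final Boolean active = hybrid.getFieldAs("active");
// hybrid.setField("address", "x") would throw: new field names cannot be added in hybrid mode.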

A row instance is in principle {@link Serializable}. However, it may contain non-serializable * fields in which case serialization will fail if the row is not serialized with Flink's @@ -54,38 +86,123 @@ import static org.apache.flink.types.RowUtils.deepHashCodeRow; @PublicEvolving public final class Row implements Serializable { - private static final long serialVersionUID = 2L; + private static final long serialVersionUID = 3L; /** The kind of change a row describes in a changelog. */ private RowKind kind; - /** The array to store actual values. */ - private final Object[] fields; + /** Fields organized by position. Either this or {@link #fieldByName} is set. */ + private final @Nullable Object[] fieldByPosition; + + /** Fields organized by name. Either this or {@link #fieldByPosition} is set. */ + private final @Nullable Map fieldByName; + + /** Mapping from field names to positions. Requires {@link #fieldByPosition} semantics. */ + private final @Nullable LinkedHashMap positionByName; + + Row( + RowKind kind, + @Nullable Object[] fieldByPosition, + @Nullable Map fieldByName, + @Nullable LinkedHashMap positionByName) { + this.kind = kind; + this.fieldByPosition = fieldByPosition; + this.fieldByName = fieldByName; + this.positionByName = positionByName; + } /** - * Create a new row instance. + * Creates a fixed-length row in position-based field mode. * - *

By default, a row describes an {@link RowKind#INSERT} change. + *

The semantics are equivalent to {@link Row#withPositions(RowKind, int)}. This constructor + * exists for backwards compatibility. * * @param kind kind of change a row describes in a changelog - * @param arity The number of fields in the row. + * @param arity the number of fields in the row */ public Row(RowKind kind, int arity) { this.kind = Preconditions.checkNotNull(kind, "Row kind must not be null."); - this.fields = new Object[arity]; + this.fieldByPosition = new Object[arity]; + this.fieldByName = null; + this.positionByName = null; } /** - * Create a new row instance. + * Creates a fixed-length row in position-based field mode. * - *

By default, a row describes an {@link RowKind#INSERT} change. + *

The semantics are equivalent to {@link Row#withPositions(int)}. This constructor exists + * for backwards compatibility. * - * @param arity The number of fields in the row. + * @param arity the number of fields in the row */ public Row(int arity) { this(RowKind.INSERT, arity); } + /** + * Creates a fixed-length row in position-based field mode. + * + *

Fields can be accessed by position via {@link #setField(int, Object)} and {@link + * #getField(int)}. + * + *

See the class documentation of {@link Row} for more information. + * + * @param kind kind of change a row describes in a changelog + * @param arity the number of fields in the row + * @return a new row instance + */ + public static Row withPositions(RowKind kind, int arity) { + return new Row(kind, new Object[arity], null, null); + } + + /** + * Creates a fixed-length row in position-based field mode. + * + *

Fields can be accessed by position via {@link #setField(int, Object)} and {@link + * #getField(int)}. + * + *

By default, a row describes an {@link RowKind#INSERT} change. + * + *

See the class documentation of {@link Row} for more information. + * + * @param arity the number of fields in the row + * @return a new row instance + */ + public static Row withPositions(int arity) { + return withPositions(RowKind.INSERT, arity); + } + + /** + * Creates a variable-length row in name-based field mode. + * + *

Fields can be accessed by name via {@link #setField(String, Object)} and {@link + * #getField(String)}. + * + *

See the class documentation of {@link Row} for more information. + * + * @param kind kind of change a row describes in a changelog + * @return a new row instance + */ + public static Row withNames(RowKind kind) { + return new Row(kind, null, new HashMap<>(), null); + } + + /** + * Creates a variable-length row in name-based field mode. + * + *

Fields can be accessed by name via {@link #setField(String, Object)} and {@link + * #getField(String)}. + * + *

By default, a row describes an {@link RowKind#INSERT} change. + * + *

See the class documentation of {@link Row} for more information. + * + * @return a new row instance + */ + public static Row withNames() { + return withNames(RowKind.INSERT); + } + /** * Returns the kind of change that this row describes in a changelog. * @@ -114,42 +231,166 @@ public final class Row implements Serializable { * *

Note: The row kind is kept separate from the fields and is not included in this number. * - * @return The number of fields in the row. + * @return the number of fields in the row */ public int getArity() { - return fields.length; + if (fieldByPosition != null) { + return fieldByPosition.length; + } else { + assert fieldByName != null; + return fieldByName.size(); + } } /** - * Returns the field's content at the specified position. + * Returns the field's content at the specified field position. + * + *

Note: The row must operate in position-based field mode. * - * @param pos The position of the field, 0-based. - * @return The field's content at the specified position. + * @param pos the position of the field, 0-based + * @return the field's content at the specified position */ public @Nullable Object getField(int pos) { - return fields[pos]; + if (fieldByPosition != null) { + return fieldByPosition[pos]; + } else { + throw new IllegalArgumentException( + "Accessing a field by position is not supported in name-based field mode."); + } + } + + /** + * Returns the field's content at the specified field position. + * + *

Note: The row must operate in position-based field mode. + * + *

This method avoids a lot of manual casting in the user implementation. + * + * @param pos the position of the field, 0-based + * @return the field's content at the specified position + */ + @SuppressWarnings("unchecked") + public T getFieldAs(int pos) { + return (T) getField(pos); + } + + /** + * Returns the field's content using the specified field name. + * + *

Note: The row must operate in name-based field mode. + * + * @param name the name of the field, set previously + * @return the field's content or null if not set previously + */ + public @Nullable Object getField(String name) { + if (fieldByName != null) { + return fieldByName.get(name); + } else if (positionByName != null) { + final Integer pos = positionByName.get(name); + if (pos == null) { + throw new IllegalArgumentException( + String.format("Unknown field name '%s' for mapping to a position.", name)); + } + assert fieldByPosition != null; + return fieldByPosition[pos]; + } else { + throw new IllegalArgumentException( + "Accessing a field by name is not supported in position-based field mode."); + } + } + + /** + * Returns the field's content using the specified field name. + * + *

Note: The row must operate in name-based field mode. + * + *

This method avoids a lot of manual casting in the user implementation. + * + * @param name the name of the field, set previously + * @return the field's content + */ + @SuppressWarnings("unchecked") + public T getFieldAs(String name) { + return (T) getField(name); } /** * Sets the field's content at the specified position. * - * @param pos The position of the field, 0-based. - * @param value The value to be assigned to the field at the specified position. + *

Note: The row must operate in position-based field mode. + * + * @param pos the position of the field, 0-based + * @param value the value to be assigned to the field at the specified position */ public void setField(int pos, @Nullable Object value) { - fields[pos] = value; + if (fieldByPosition != null) { + fieldByPosition[pos] = value; + } else { + throw new IllegalArgumentException( + "Accessing a field by position is not supported in name-based field mode."); + } } - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - for (int i = 0; i < fields.length; i++) { - if (i > 0) { - sb.append(","); + /** + * Sets the field's content using the specified field name. + * + *

Note: The row must operate in name-based field mode. + * + * @param name the name of the field + * @param value the value to be assigned to the field + */ + public void setField(String name, @Nullable Object value) { + if (fieldByName != null) { + fieldByName.put(name, value); + } else if (positionByName != null) { + final Integer pos = positionByName.get(name); + if (pos == null) { + throw new IllegalArgumentException( + String.format( + "Unknown field name '%s' for mapping to a row position. " + + "Available names are: %s", + name, positionByName.keySet())); } - sb.append(StringUtils.arrayAwareToString(fields[i])); + assert fieldByPosition != null; + fieldByPosition[pos] = value; + } else { + throw new IllegalArgumentException( + "Accessing a field by name is not supported in position-based field mode."); } - return sb.toString(); + } + + /** + * Returns the set of field names if this row operates in name-based field mode, otherwise null. + * + *

This method is a helper method for serializers and converters but can also be useful for + * other row transformations. + * + * @param includeNamedPositions whether or not to include named positions when this row operates + * in a hybrid field mode + */ + public @Nullable Set getFieldNames(boolean includeNamedPositions) { + if (fieldByName != null) { + return fieldByName.keySet(); + } + if (includeNamedPositions && positionByName != null) { + return positionByName.keySet(); + } + return null; + } + + /** Clears all fields of this row. */ + public void clear() { + if (fieldByPosition != null) { + Arrays.fill(fieldByPosition, null); + } else { + assert fieldByName != null; + fieldByName.clear(); + } + } + + @Override + public String toString() { + return RowUtils.deepToStringRow(kind, fieldByPosition, fieldByName); } @Override @@ -160,13 +401,21 @@ public final class Row implements Serializable { if (o == null || getClass() != o.getClass()) { return false; } - final Row row = (Row) o; - return deepEqualsRow(this, row); + final Row other = (Row) o; + return deepEqualsRow( + kind, + fieldByPosition, + fieldByName, + positionByName, + other.kind, + other.fieldByPosition, + other.fieldByName, + other.positionByName); } @Override public int hashCode() { - return deepHashCodeRow(this); + return deepHashCodeRow(kind, fieldByPosition, fieldByName); } // -------------------------------------------------------------------------------------------- @@ -174,8 +423,10 @@ public final class Row implements Serializable { // -------------------------------------------------------------------------------------------- /** - * Creates a new row and assigns the given values to the row's fields. This is more convenient - * than using the constructor. + * Creates a fixed-length row in position-based field mode and assigns the given values to the + * row's fields. + * + *

This method should be more convenient than {@link Row#withPositions(int)} in many cases. * *

For example: * @@ -186,7 +437,7 @@ public final class Row implements Serializable { * instead of * *

-     *     Row row = new Row(3);
+     *     Row row = Row.withPositions(3);
      *     row.setField(0, "hello");
      *     row.setField(1, true);
      *     row.setField(2, 1L);
@@ -195,7 +446,7 @@ public final class Row implements Serializable {
      * 

By default, a row describes an {@link RowKind#INSERT} change. */ public static Row of(Object... values) { - Row row = new Row(values.length); + final Row row = new Row(values.length); for (int i = 0; i < values.length; i++) { row.setField(i, values[i]); } @@ -203,8 +454,11 @@ public final class Row implements Serializable { } /** - * Creates a new row with given kind and assigns the given values to the row's fields. This is - * more convenient than using the constructor. + * Creates a fixed-length row in position-based field mode with given kind and assigns the given + * values to the row's fields. + * + *

This method should be more convenient than {@link Row#withPositions(RowKind, int)} in many + * cases. * *

For example: * @@ -215,15 +469,14 @@ public final class Row implements Serializable { * instead of * *

-     *     Row row = new Row(3);
-     *     row.setKind(RowKind.INSERT);
+     *     Row row = Row.withPositions(RowKind.INSERT, 3);
      *     row.setField(0, "hello");
      *     row.setField(1, true);
      *     row.setField(2, 1L);
      * 
*/ public static Row ofKind(RowKind kind, Object... values) { - Row row = new Row(kind, values.length); + final Row row = new Row(kind, values.length); for (int i = 0; i < values.length; i++) { row.setField(i, values[i]); } @@ -233,11 +486,42 @@ public final class Row implements Serializable { /** * Creates a new row which is copied from another row (including its {@link RowKind}). * - *

This method does not perform a deep copy. + *

This method does not perform a deep copy. Use {@link RowSerializer#copy(Row)} if required. */ public static Row copy(Row row) { - final Row newRow = new Row(row.kind, row.fields.length); - System.arraycopy(row.fields, 0, newRow.fields, 0, row.fields.length); + final Object[] newFieldByPosition; + if (row.fieldByPosition != null) { + newFieldByPosition = new Object[row.fieldByPosition.length]; + System.arraycopy( + row.fieldByPosition, 0, newFieldByPosition, 0, newFieldByPosition.length); + } else { + newFieldByPosition = null; + } + + final Map newFieldByName; + if (row.fieldByName != null) { + newFieldByName = new HashMap<>(row.fieldByName); + } else { + newFieldByName = null; + } + + return new Row(row.kind, newFieldByPosition, newFieldByName, row.positionByName); + } + + /** + * Creates a new row with projected fields and identical {@link RowKind} from another row. + * + *

This method does not perform a deep copy. + * + *

Note: The row must operate in position-based field mode. Field names are not projected. + * + * @param fieldPositions field indices to be projected + */ + public static Row project(Row row, int[] fieldPositions) { + final Row newRow = Row.withPositions(row.kind, fieldPositions.length); + for (int i = 0; i < fieldPositions.length; i++) { + newRow.setField(i, row.getField(fieldPositions[i])); + } return newRow; } @@ -246,12 +530,14 @@ public final class Row implements Serializable { * *

This method does not perform a deep copy. * - * @param fields field indices to be projected + *

Note: The row must operate in name-based field mode. + * + * @param fieldNames field names to be projected */ - public static Row project(Row row, int[] fields) { - final Row newRow = new Row(row.kind, fields.length); - for (int i = 0; i < fields.length; i++) { - newRow.fields[i] = row.fields[fields[i]]; + public static Row project(Row row, String[] fieldNames) { + final Row newRow = Row.withNames(row.getKind()); + for (String fieldName : fieldNames) { + newRow.setField(fieldName, row.getField(fieldName)); } return newRow; } @@ -262,24 +548,44 @@ public final class Row implements Serializable { * RowKind} of the result. * *

This method does not perform a deep copy. + * + *

Note: All rows must operate in position-based field mode. */ public static Row join(Row first, Row... remainings) { - int newLength = first.fields.length; + Preconditions.checkArgument( + first.fieldByPosition != null, + "All rows must operate in position-based field mode."); + int newLength = first.fieldByPosition.length; for (Row remaining : remainings) { - newLength += remaining.fields.length; + Preconditions.checkArgument( + remaining.fieldByPosition != null, + "All rows must operate in position-based field mode."); + newLength += remaining.fieldByPosition.length; } final Row joinedRow = new Row(first.kind, newLength); int index = 0; // copy the first row - System.arraycopy(first.fields, 0, joinedRow.fields, index, first.fields.length); - index += first.fields.length; + assert joinedRow.fieldByPosition != null; + System.arraycopy( + first.fieldByPosition, + 0, + joinedRow.fieldByPosition, + index, + first.fieldByPosition.length); + index += first.fieldByPosition.length; // copy the remaining rows for (Row remaining : remainings) { - System.arraycopy(remaining.fields, 0, joinedRow.fields, index, remaining.fields.length); - index += remaining.fields.length; + assert remaining.fieldByPosition != null; + System.arraycopy( + remaining.fieldByPosition, + 0, + joinedRow.fieldByPosition, + index, + remaining.fieldByPosition.length); + index += remaining.fieldByPosition.length; } return joinedRow; diff --git a/flink-core/src/main/java/org/apache/flink/types/RowUtils.java b/flink-core/src/main/java/org/apache/flink/types/RowUtils.java index 64fdbffe638..7cdd76d9267 100644 --- a/flink-core/src/main/java/org/apache/flink/types/RowUtils.java +++ b/flink-core/src/main/java/org/apache/flink/types/RowUtils.java @@ -18,10 +18,15 @@ package org.apache.flink.types; +import org.apache.flink.annotation.Internal; import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.util.StringUtils; + +import javax.annotation.Nullable; import java.util.Arrays; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -65,6 +70,17 @@ public final class RowUtils { } } + // -------------------------------------------------------------------------------------------- + // Internal utilities + // -------------------------------------------------------------------------------------------- + + /** Internal utility for creating a row in static named-position field mode. */ + @Internal + public static Row createRowWithNamedPositions( + RowKind kind, Object[] fieldByPosition, LinkedHashMap positionByName) { + return new Row(kind, fieldByPosition, null, positionByName); + } + // -------------------------------------------------------------------------------------------- // Default scoped for Row class only // -------------------------------------------------------------------------------------------- @@ -73,46 +89,103 @@ public final class RowUtils { * Compares two objects with proper (nested) equality semantics. This method supports all * external and most internal conversion classes of the table ecosystem. 
*/ - static boolean deepEqualsRow(Row row1, Row row2) { - if (row1.getKind() != row2.getKind()) { + static boolean deepEqualsRow( + RowKind kind1, + @Nullable Object[] fieldByPosition1, + @Nullable Map fieldByName1, + @Nullable LinkedHashMap positionByName1, + RowKind kind2, + @Nullable Object[] fieldByPosition2, + @Nullable Map fieldByName2, + @Nullable LinkedHashMap positionByName2) { + if (kind1 != kind2) { return false; } - if (row1.getArity() != row2.getArity()) { - return false; + // positioned == positioned + else if (fieldByPosition1 != null && fieldByPosition2 != null) { + // positionByName is not included + return deepEqualsInternal(fieldByPosition1, fieldByPosition2); } - for (int pos = 0; pos < row1.getArity(); pos++) { - final Object f1 = row1.getField(pos); - final Object f2 = row2.getField(pos); - if (!deepEqualsInternal(f1, f2)) { - return false; - } + // named == named + else if (fieldByName1 != null && fieldByName2 != null) { + return deepEqualsInternal(fieldByName1, fieldByName2); } - return true; + // named positioned == named + else if (positionByName1 != null && fieldByName2 != null) { + return deepEqualsNamedRows(fieldByPosition1, positionByName1, fieldByName2); + } + // named == named positioned + else if (positionByName2 != null && fieldByName1 != null) { + return deepEqualsNamedRows(fieldByPosition2, positionByName2, fieldByName1); + } + return false; } /** * Hashes two objects with proper (nested) equality semantics. This method supports all external * and most internal conversion classes of the table ecosystem. */ - static int deepHashCodeRow(Row row) { - int result = row.getKind().toByteValue(); // for stable hash across JVM instances - for (int i = 0; i < row.getArity(); i++) { - result = 31 * result + deepHashCodeInternal(row.getField(i)); + static int deepHashCodeRow( + RowKind kind, + @Nullable Object[] fieldByPosition, + @Nullable Map fieldByName) { + int result = kind.toByteValue(); // for stable hash across JVM instances + if (fieldByPosition != null) { + // positionByName is not included + result = 31 * result + deepHashCodeInternal(fieldByPosition); + } else { + result = 31 * result + deepHashCodeInternal(fieldByName); } return result; } + /** + * Converts a row to a string representation. This method supports all external and most + * internal conversion classes of the table ecosystem. 
+ */ + static String deepToStringRow( + RowKind kind, + @Nullable Object[] fieldByPosition, + @Nullable Map fieldByName) { + final StringBuilder sb = new StringBuilder(); + if (fieldByPosition != null) { + // TODO enable this for FLINK-18090 + // sb.append(kind.shortString()); + // deepToStringArray(sb, fieldByPosition); + deepToStringArrayLegacy(sb, fieldByPosition); + } else { + assert fieldByName != null; + sb.append(kind.shortString()); + deepToStringMap(sb, fieldByName); + } + return sb.toString(); + } + // -------------------------------------------------------------------------------------------- - // Internal utilities + // Helper methods // -------------------------------------------------------------------------------------------- + private static boolean deepEqualsNamedRows( + Object[] fieldByPosition1, + LinkedHashMap positionByName1, + Map fieldByName2) { + for (Map.Entry entry : fieldByName2.entrySet()) { + final Integer pos = positionByName1.get(entry.getKey()); + if (pos == null) { + return false; + } + if (!deepEqualsInternal(fieldByPosition1[pos], entry.getValue())) { + return false; + } + } + return true; + } + private static boolean deepEqualsInternal(Object o1, Object o2) { if (o1 == o2) { return true; } else if (o1 == null || o2 == null) { return false; - } else if (o1 instanceof Row && o2 instanceof Row) { - return deepEqualsRow((Row) o1, (Row) o2); } else if (o1 instanceof Object[] && o2 instanceof Object[]) { return deepEqualsArray((Object[]) o1, (Object[]) o2); } else if (o1 instanceof Map && o2 instanceof Map) { @@ -205,9 +278,7 @@ public final class RowUtils { if (o == null) { return 0; } - if (o instanceof Row) { - return deepHashCodeRow((Row) o); - } else if (o instanceof Object[]) { + if (o instanceof Object[]) { return deepHashCodeArray((Object[]) o); } else if (o instanceof Map) { return deepHashCodeMap((Map) o); @@ -241,6 +312,71 @@ public final class RowUtils { return result; } + private static void deepToStringInternal(StringBuilder sb, Object o) { + if (o instanceof Object[]) { + deepToStringArray(sb, (Object[]) o); + } else if (o instanceof Map) { + deepToStringMap(sb, (Map) o); + } else if (o instanceof List) { + deepToStringList(sb, (List) o); + } else { + sb.append(StringUtils.arrayAwareToString(o)); + } + } + + private static void deepToStringArray(StringBuilder sb, Object[] a) { + sb.append('['); + boolean isFirst = true; + for (Object o : a) { + if (isFirst) { + isFirst = false; + } else { + sb.append(", "); + } + deepToStringInternal(sb, o); + } + sb.append(']'); + } + + private static void deepToStringArrayLegacy(StringBuilder sb, Object[] a) { + for (int i = 0; i < a.length; i++) { + if (i > 0) { + sb.append(","); + } + sb.append(StringUtils.arrayAwareToString(a[i])); + } + } + + private static void deepToStringMap(StringBuilder sb, Map m) { + sb.append('{'); + boolean isFirst = true; + for (Map.Entry entry : m.entrySet()) { + if (isFirst) { + isFirst = false; + } else { + sb.append(", "); + } + deepToStringInternal(sb, entry.getKey()); + sb.append('='); + deepToStringInternal(sb, entry.getValue()); + } + sb.append('}'); + } + + private static void deepToStringList(StringBuilder sb, List l) { + sb.append('['); + boolean isFirst = true; + for (E element : l) { + if (isFirst) { + isFirst = false; + } else { + sb.append(", "); + } + deepToStringInternal(sb, element); + } + sb.append(']'); + } + private RowUtils() { // no instantiation } diff --git a/flink-core/src/test/java/org/apache/flink/api/java/typeutils/runtime/RowSerializerTest.java 
b/flink-core/src/test/java/org/apache/flink/api/java/typeutils/runtime/RowSerializerTest.java index ddbcb71d1b4..41026539bc9 100644 --- a/flink-core/src/test/java/org/apache/flink/api/java/typeutils/runtime/RowSerializerTest.java +++ b/flink-core/src/test/java/org/apache/flink/api/java/typeutils/runtime/RowSerializerTest.java @@ -20,6 +20,7 @@ package org.apache.flink.api.java.typeutils.runtime; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.common.typeutils.SerializerTestInstance; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.tuple.Tuple3; @@ -28,30 +29,61 @@ import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.types.Row; import org.apache.flink.types.RowKind; +import org.apache.flink.types.RowUtils; import org.junit.Test; import java.io.Serializable; +import java.util.LinkedHashMap; import java.util.Objects; public class RowSerializerTest { @Test public void testRowSerializer() { - TypeInformation typeInfo = - new RowTypeInfo(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO); - Row row1 = new Row(2); - row1.setKind(RowKind.UPDATE_BEFORE); - row1.setField(0, 1); - row1.setField(1, "a"); + final TypeInformation rowTypeInfo = + Types.ROW_NAMED( + new String[] {"a", "b", "c", "d"}, + Types.INT, + Types.STRING, + Types.DOUBLE, + Types.BOOLEAN); - Row row2 = new Row(2); - row2.setKind(RowKind.INSERT); - row2.setField(0, 2); - row2.setField(1, null); + final Row positionedRow = Row.withPositions(RowKind.UPDATE_BEFORE, 4); + positionedRow.setKind(RowKind.UPDATE_BEFORE); + positionedRow.setField(0, 1); + positionedRow.setField(1, "a"); + positionedRow.setField(2, null); + positionedRow.setField(3, false); - TypeSerializer serializer = typeInfo.createSerializer(new ExecutionConfig()); - RowSerializerTestInstance instance = new RowSerializerTestInstance(serializer, row1, row2); + final Row namedRow = Row.withNames(RowKind.UPDATE_BEFORE); + namedRow.setField("a", 1); + namedRow.setField("b", "a"); + namedRow.setField("c", null); + namedRow.setField("d", false); + + final Row sparseNamedRow = Row.withNames(RowKind.UPDATE_BEFORE); + namedRow.setField("a", 1); + namedRow.setField("b", "a"); + namedRow.setField("d", false); // "c" is missing + + final LinkedHashMap positionByName = new LinkedHashMap<>(); + positionByName.put("a", 0); + positionByName.put("b", 1); + positionByName.put("c", 2); + positionByName.put("d", 3); + final Row namedPositionedRow = + RowUtils.createRowWithNamedPositions( + RowKind.UPDATE_BEFORE, new Object[4], positionByName); + namedPositionedRow.setField("a", 1); + namedPositionedRow.setField(1, "a"); + namedPositionedRow.setField(2, null); + namedPositionedRow.setField("d", false); + + final TypeSerializer serializer = rowTypeInfo.createSerializer(new ExecutionConfig()); + final RowSerializerTestInstance instance = + new RowSerializerTestInstance( + serializer, positionedRow, namedRow, sparseNamedRow, namedPositionedRow); instance.testAll(); } diff --git a/flink-core/src/test/java/org/apache/flink/testutils/DeeplyEqualsChecker.java b/flink-core/src/test/java/org/apache/flink/testutils/DeeplyEqualsChecker.java index 4aa48fdd2e3..61a19260903 100644 --- a/flink-core/src/test/java/org/apache/flink/testutils/DeeplyEqualsChecker.java +++ 
b/flink-core/src/test/java/org/apache/flink/testutils/DeeplyEqualsChecker.java @@ -20,7 +20,6 @@ package org.apache.flink.testutils; import org.apache.flink.api.java.tuple.Tuple; import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.types.Row; import java.lang.reflect.Array; import java.util.ArrayList; @@ -35,7 +34,6 @@ import java.util.function.BiFunction; *

    *
  • {@link Tuple}s *
  • Java arrays - *
  • {@link Row} *
  • {@link Throwable} *
* @@ -93,8 +91,6 @@ public class DeeplyEqualsChecker { return deepEqualsArray(e1, e2); } else if (e1 instanceof Tuple && e2 instanceof Tuple) { return deepEqualsTuple((Tuple) e1, (Tuple) e2); - } else if (e1 instanceof Row && e2 instanceof Row) { - return deepEqualsRow((Row) e1, (Row) e2); } else if (e1 instanceof Throwable && e2 instanceof Throwable) { return ((Throwable) e1).getMessage().equals(((Throwable) e2).getMessage()); } else { @@ -138,22 +134,4 @@ public class DeeplyEqualsChecker { return true; } - - private boolean deepEqualsRow(Row row1, Row row2) { - int arity = row1.getArity(); - - if (row1.getArity() != row2.getArity()) { - return false; - } - - for (int i = 0; i < arity; i++) { - Object copiedValue = row1.getField(i); - Object element = row2.getField(i); - if (!deepEquals(copiedValue, element)) { - return false; - } - } - - return true; - } } diff --git a/flink-core/src/test/java/org/apache/flink/types/RowTest.java b/flink-core/src/test/java/org/apache/flink/types/RowTest.java index 807f88f12d0..f72e947e506 100644 --- a/flink-core/src/test/java/org/apache/flink/types/RowTest.java +++ b/flink-core/src/test/java/org/apache/flink/types/RowTest.java @@ -24,71 +24,314 @@ import org.junit.Test; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.Map; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; +import static org.junit.internal.matchers.ThrowableMessageMatcher.hasMessage; +/** Tests for {@link Row} and {@link RowUtils}. 
*/ public class RowTest { + @Test - public void testRowToString() { - Row row = new Row(5); - row.setField(0, 1); - row.setField(1, "hello"); + public void testRowNamed() { + final Row row = Row.withNames(RowKind.DELETE); + + // test getters and setters + row.setField("a", 42); + row.setField("b", true); + row.setField("c", null); + assertThat(row.getFieldNames(false), containsInAnyOrder("a", "b", "c")); + assertThat(row.getArity(), equalTo(3)); + assertThat(row.getKind(), equalTo(RowKind.DELETE)); + assertThat(row.getField("a"), equalTo(42)); + assertThat(row.getField("b"), equalTo(true)); + assertThat(row.getField("c"), equalTo(null)); + + // test toString + assertThat(row.toString(), equalTo("-D{a=42, b=true, c=null}")); + + // test override + row.setField("a", 13); + row.setField("c", "Hello"); + assertThat(row.getField("a"), equalTo(13)); + assertThat(row.getField("b"), equalTo(true)); + assertThat(row.getField("c"), equalTo("Hello")); + + // test equality + final Row otherRow1 = Row.withNames(RowKind.DELETE); + otherRow1.setField("a", 13); + otherRow1.setField("b", true); + otherRow1.setField("c", "Hello"); + assertThat(row.hashCode(), equalTo(otherRow1.hashCode())); + assertThat(row, equalTo(otherRow1)); + + // test inequality + final Row otherRow2 = Row.withNames(RowKind.DELETE); + otherRow2.setField("a", 13); + otherRow2.setField("b", false); // diff here + otherRow2.setField("c", "Hello"); + assertThat(row.hashCode(), not(equalTo(otherRow2.hashCode()))); + assertThat(row, not(equalTo(otherRow2))); + + // test clear + row.clear(); + assertThat(row.getArity(), equalTo(0)); + assertThat(row.getFieldNames(false), empty()); + assertThat(row.toString(), equalTo("-D{}")); + + // test invalid setter + try { + row.setField(0, 13); + fail(); + } catch (Throwable t) { + assertThat(t, hasMessage(containsString("not supported in name-based field mode"))); + } + + // test invalid getter + try { + assertNull(row.getField(0)); + fail(); + } catch (Throwable t) { + assertThat(t, hasMessage(containsString("not supported in name-based field mode"))); + } + } + + @Test + public void testRowPositioned() { + final Row row = Row.withPositions(RowKind.DELETE, 3); + + // test getters and setters + row.setField(0, 42); + row.setField(1, true); row.setField(2, null); - row.setField(3, new Tuple2<>(2, "hi")); - row.setField(4, "hello world"); + assertThat(row.getFieldNames(false), equalTo(null)); + assertThat(row.getArity(), equalTo(3)); + assertThat(row.getKind(), equalTo(RowKind.DELETE)); + assertThat(row.getField(0), equalTo(42)); + assertThat(row.getField(1), equalTo(true)); + assertThat(row.getField(2), equalTo(null)); + + // test toString + // TODO enable this for FLINK-18090 + // assertThat(row.toString(), equalTo("-D[42, true, null]")); + assertThat(row.toString(), equalTo("42,true,null")); + + // test override + row.setField(0, 13); + row.setField(2, "Hello"); + assertThat(row.getField(0), equalTo(13)); + assertThat(row.getField(1), equalTo(true)); + assertThat(row.getField(2), equalTo("Hello")); + + // test equality + final Row otherRow1 = Row.withPositions(RowKind.DELETE, 3); + otherRow1.setField(0, 13); + otherRow1.setField(1, true); + otherRow1.setField(2, "Hello"); + assertThat(row.hashCode(), equalTo(otherRow1.hashCode())); + assertThat(row, equalTo(otherRow1)); + + // test inequality + final Row otherRow2 = Row.withPositions(RowKind.DELETE, 3); + otherRow2.setField(0, 13); + otherRow2.setField(1, false); // diff here + otherRow2.setField(2, "Hello"); + assertThat(row.hashCode(), 
not(equalTo(otherRow2.hashCode()))); + assertThat(row, not(equalTo(otherRow2))); + + // test clear + row.clear(); + assertThat(row.getArity(), equalTo(3)); + assertThat(row.getFieldNames(false), equalTo(null)); + // TODO enable this for FLINK-18090 + // assertThat(row.toString(), equalTo("-D[null, null, null]")); + assertThat(row.toString(), equalTo("null,null,null")); + + // test invalid setter + try { + row.setField("a", 13); + fail(); + } catch (Throwable t) { + assertThat(t, hasMessage(containsString("not supported in position-based field mode"))); + } - assertEquals("1,hello,null,(2,hi),hello world", row.toString()); + // test invalid getter + try { + assertNull(row.getField("a")); + fail(); + } catch (Throwable t) { + assertThat(t, hasMessage(containsString("not supported in position-based field mode"))); + } + } + + @Test + public void testRowNamedPositioned() { + final LinkedHashMap positionByName = new LinkedHashMap<>(); + positionByName.put("a", 0); + positionByName.put("b", 1); + positionByName.put("c", 2); + final Row row = + RowUtils.createRowWithNamedPositions(RowKind.DELETE, new Object[3], positionByName); + + // test getters and setters + row.setField(0, 42); + row.setField("b", true); + row.setField(2, null); + assertThat(row.getFieldNames(false), equalTo(null)); + assertThat(row.getFieldNames(true), contains("a", "b", "c")); + assertThat(row.getArity(), equalTo(3)); + assertThat(row.getKind(), equalTo(RowKind.DELETE)); + assertThat(row.getField(0), equalTo(42)); + assertThat(row.getField(1), equalTo(true)); + assertThat(row.getField("c"), equalTo(null)); + + // test toString + // TODO enable this for FLINK-18090 + // assertThat(row.toString(), equalTo("-D[42, true, null]")); + assertThat(row.toString(), equalTo("42,true,null")); + + // test override + row.setField("a", 13); + row.setField(2, "Hello"); + assertThat(row.getField(0), equalTo(13)); + assertThat(row.getField("b"), equalTo(true)); + assertThat(row.getField(2), equalTo("Hello")); + + // test equality + final Row otherRow1 = Row.withPositions(RowKind.DELETE, 3); + otherRow1.setField(0, 13); + otherRow1.setField(1, true); + otherRow1.setField(2, "Hello"); + assertThat(row.hashCode(), equalTo(otherRow1.hashCode())); + assertThat(row, equalTo(otherRow1)); + + // test inequality + final Row otherRow2 = Row.withPositions(RowKind.DELETE, 3); + otherRow2.setField(0, 13); + otherRow2.setField(1, false); // diff here + otherRow2.setField(2, "Hello"); + assertThat(row.hashCode(), not(equalTo(otherRow2.hashCode()))); + assertThat(row, not(equalTo(otherRow2))); + + // test clear + row.clear(); + assertThat(row.getArity(), equalTo(3)); + assertThat(row.getFieldNames(true), contains("a", "b", "c")); + // TODO enable this for FLINK-18090 + // assertThat(row.toString(), equalTo("-D[null, null, null]")); + assertThat(row.toString(), equalTo("null,null,null")); + + // test invalid setter + try { + row.setField("DOES_NOT_EXIST", 13); + fail(); + } catch (Throwable t) { + assertThat(t, hasMessage(containsString("Unknown field name 'DOES_NOT_EXIST'"))); + } + + // test invalid getter + try { + assertNull(row.getField("DOES_NOT_EXIST")); + fail(); + } catch (Throwable t) { + assertThat(t, hasMessage(containsString("Unknown field name 'DOES_NOT_EXIST'"))); + } } @Test public void testRowOf() { - Row row1 = Row.of(1, "hello", null, Tuple2.of(2L, "hi"), true); - Row row2 = new Row(5); + final Row row1 = Row.of(1, "hello", null, Tuple2.of(2L, "hi"), true); + + final Row row2 = Row.withPositions(5); row2.setField(0, 1); row2.setField(1, 
"hello"); row2.setField(2, null); row2.setField(3, new Tuple2<>(2L, "hi")); row2.setField(4, true); + assertEquals(row1, row2); } @Test - public void testRowCopy() { - Row row = new Row(5); + public void testRowCopyPositioned() { + final Row row = Row.withPositions(5); row.setField(0, 1); row.setField(1, "hello"); row.setField(2, null); row.setField(3, new Tuple2<>(2, "hi")); row.setField(4, "hello world"); - Row copy = Row.copy(row); + final Row copy = Row.copy(row); assertEquals(row, copy); assertNotSame(row, copy); } @Test - public void testRowProject() { - Row row = new Row(5); + public void testRowCopyNamed() { + final Row row = Row.withNames(); + row.setField("a", 1); + row.setField("b", "hello"); + row.setField("c", null); + row.setField("d", new Tuple2<>(2, "hi")); + row.setField("e", "hello world"); + + final Row copy = Row.copy(row); + assertEquals(row, copy); + assertNotSame(row, copy); + } + + @Test + public void testRowProjectPositioned() { + final Row row = Row.withPositions(5); row.setField(0, 1); row.setField(1, "hello"); row.setField(2, null); row.setField(3, new Tuple2<>(2, "hi")); row.setField(4, "hello world"); - Row projected = Row.project(row, new int[] {0, 2, 4}); + final Row projected = Row.project(row, new int[] {0, 2, 4}); - Row expected = new Row(3); + final Row expected = Row.withPositions(3); expected.setField(0, 1); expected.setField(1, null); expected.setField(2, "hello world"); + assertEquals(expected, projected); } @Test - public void testRowJoin() { + public void testRowProjectNamed() { + final Row row = Row.withNames(); + row.setField("a", 1); + row.setField("b", "hello"); + row.setField("c", null); + row.setField("d", new Tuple2<>(2, "hi")); + row.setField("e", "hello world"); + + final Row projected = Row.project(row, new String[] {"a", "c", "e"}); + + final Row expected = Row.withNames(); + expected.setField("a", 1); + expected.setField("c", null); + expected.setField("e", "hello world"); + + assertEquals(expected, projected); + } + + @Test + public void testRowJoinPositioned() { Row row1 = new Row(2); row1.setField(0, 1); row1.setField(1, "hello"); @@ -112,7 +355,7 @@ public class RowTest { } @Test - public void testDeepEqualsAndHashCode() { + public void testDeepEqualsAndHashCodePositioned() { final Map originalMap = new HashMap<>(); originalMap.put("k1", new byte[] {1, 2, 3}); originalMap.put("k2", new byte[] {3, 4, 6}); @@ -203,4 +446,59 @@ public class RowTest { assertNotEquals(row.hashCode(), originalRow.hashCode()); } } + + @Test + public void testDeepEqualsCodeNamed() { + final Row named = Row.withNames(RowKind.DELETE); + named.setField("a", 12); // "b" is missing due to sparsity + named.setField("c", true); + + final LinkedHashMap positionByName = new LinkedHashMap<>(); + positionByName.put("a", 0); + positionByName.put("b", 1); + positionByName.put("c", 2); + final Row namedPositioned = + RowUtils.createRowWithNamedPositions(RowKind.DELETE, new Object[3], positionByName); + namedPositioned.setField("a", 12); + namedPositioned.setField("b", null); + namedPositioned.setField("c", true); + + assertThat(named, equalTo(namedPositioned)); + assertThat(namedPositioned, equalTo(named)); + + named.setField("b", "Hello"); + assertThat(named, not(equalTo(namedPositioned))); + assertThat(namedPositioned, not(equalTo(named))); + } + + @Test + public void testDeepToString() { + final Row row = Row.withNames(RowKind.UPDATE_BEFORE); + row.setField("a", 1); + row.setField("b", "hello"); + row.setField("c", null); + row.setField("d", new Tuple2<>(2, "hi")); + 
row.setField("e", "hello world"); + row.setField("f", new int[][] {{1}, null, {3, 4}}); + row.setField("g", new Boolean[][] {{true}, null, {false, false}}); + final Map map = new HashMap<>(); + map.put("a", new Integer[] {1, 2, 3, 4}); + map.put("b", new Integer[] {}); + map.put("c", null); + row.setField("h", map); + + assertThat( + row.toString(), + equalTo( + "-U{" + + "a=1, " + + "b=hello, " + + "c=null, " + + "d=(2,hi), " + + "e=hello world, " + + "f=[[1], null, [3, 4]], " + + "g=[[true], null, [false, false]], " + + "h={a=[1, 2, 3, 4], b=[], c=null}" + + "}")); + } } diff --git a/flink-python/src/test/java/org/apache/flink/streaming/api/utils/PythonTypeUtilsTest.java b/flink-python/src/test/java/org/apache/flink/streaming/api/utils/PythonTypeUtilsTest.java index d10feb65f0a..b24403cd013 100644 --- a/flink-python/src/test/java/org/apache/flink/streaming/api/utils/PythonTypeUtilsTest.java +++ b/flink-python/src/test/java/org/apache/flink/streaming/api/utils/PythonTypeUtilsTest.java @@ -181,7 +181,7 @@ public class PythonTypeUtilsTest { rowTypeInfo); assertEquals( convertedTypeSerializer, - new RowSerializer(new TypeSerializer[] {IntSerializer.INSTANCE}, false)); + new RowSerializer(new TypeSerializer[] {IntSerializer.INSTANCE}, null)); TupleTypeInfo tupleTypeInfo = (TupleTypeInfo) Types.TUPLE(Types.INT); convertedTypeSerializer = diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/arrow/sources/RowArrowSourceFunctionTest.java b/flink-python/src/test/java/org/apache/flink/table/runtime/arrow/sources/RowArrowSourceFunctionTest.java index cdd2202142c..329bcec85f9 100644 --- a/flink-python/src/test/java/org/apache/flink/table/runtime/arrow/sources/RowArrowSourceFunctionTest.java +++ b/flink-python/src/test/java/org/apache/flink/table/runtime/arrow/sources/RowArrowSourceFunctionTest.java @@ -51,7 +51,7 @@ public class RowArrowSourceFunctionTest extends ArrowSourceFunctionTestBase public RowArrowSourceFunctionTest() { super( VectorSchemaRoot.create(ArrowUtils.toArrowSchema(rowType), allocator), - new RowSerializer(new TypeSerializer[] {StringSerializer.INSTANCE}, false), + new RowSerializer(new TypeSerializer[] {StringSerializer.INSTANCE}), Comparator.comparing(o -> (String) (o.getField(0)))); } diff --git a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/RowFunctionITCase.java b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/RowFunctionITCase.java index e01c6d2a678..1b09ad2d69a 100644 --- a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/RowFunctionITCase.java +++ b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/RowFunctionITCase.java @@ -92,7 +92,7 @@ public class RowFunctionITCase extends BuiltInFunctionTestBase { public static class TakesRow extends ScalarFunction { public @DataTypeHint("ROW") Row eval( @DataTypeHint("ROW") Row row, Integer i) { - row.setField(0, (int) row.getField(0) + i); + row.setField("i", (int) row.getField("i") + i); return row; } } diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/TemporalTableFunctionJoinTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/TemporalTableFunctionJoinTest.xml index a3c6495aec1..89f7aa57bbc 100644 --- 
a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/TemporalTableFunctionJoinTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/table/TemporalTableFunctionJoinTest.xml @@ -25,7 +25,7 @@ LogicalJoin(condition=[=($3, $1)], joinType=[inner]) : +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{}]) : :- LogicalProject(o_rowtime=[AS($0, _UTF-16LE'o_rowtime')], o_comment=[AS($1, _UTF-16LE'o_comment')], o_amount=[AS($2, _UTF-16LE'o_amount')], o_currency=[AS($3, _UTF-16LE'o_currency')], o_secondary_key=[AS($4, _UTF-16LE'o_secondary_key')]) : : +- LogicalTableScan(table=[[default_catalog, default_database, Orders]]) -: +- LogicalTableFunctionScan(invocation=[org$apache$flink$table$functions$TemporalTableFunctionImpl$87edb49a77e545ed3da6318629a54024($0)], rowType=[RecordType(TIMESTAMP(3) rowtime, VARCHAR(2147483647) comment, VARCHAR(2147483647) currency, INTEGER rate, INTEGER secondary_key)], elementType=[class [Ljava.lang.Object;]) +: +- LogicalTableFunctionScan(invocation=[org$apache$flink$table$functions$TemporalTableFunctionImpl$11da53fa5c705009c63c1059f9311821($0)], rowType=[RecordType(TIMESTAMP(3) rowtime, VARCHAR(2147483647) comment, VARCHAR(2147483647) currency, INTEGER rate, INTEGER secondary_key)], elementType=[class [Ljava.lang.Object;]) +- LogicalTableScan(table=[[default_catalog, default_database, ThirdTable]]) ]]> @@ -53,7 +53,7 @@ LogicalProject(rate=[AS(*($0, $4), _UTF-16LE'rate')]) +- LogicalFilter(condition=[=($3, $1)]) +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{}]) :- LogicalTableScan(table=[[default_catalog, default_database, Orders]]) - +- LogicalTableFunctionScan(invocation=[org$apache$flink$table$functions$TemporalTableFunctionImpl$9f8cd64a3b5a060e7794a65524cd070a($2)], rowType=[RecordType(VARCHAR(2147483647) currency, INTEGER rate, TIME ATTRIBUTE(ROWTIME) rowtime)], elementType=[class [Ljava.lang.Object;]) + +- LogicalTableFunctionScan(invocation=[org$apache$flink$table$functions$TemporalTableFunctionImpl$7b64fb5334c47de8df24f3ab20394c19($2)], rowType=[RecordType(VARCHAR(2147483647) currency, INTEGER rate, TIME ATTRIBUTE(ROWTIME) rowtime)], elementType=[class [Ljava.lang.Object;]) ]]> @@ -95,7 +95,7 @@ LogicalProject(rate=[AS(*($0, $4), _UTF-16LE'rate')]) +- LogicalFilter(condition=[=($3, $1)]) +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{}]) :- LogicalTableScan(table=[[default_catalog, default_database, ProctimeOrders]]) - +- LogicalTableFunctionScan(invocation=[org$apache$flink$table$functions$TemporalTableFunctionImpl$5e34ee5bc251237bdaefabb25bfc63dc($2)], rowType=[RecordType(VARCHAR(2147483647) currency, INTEGER rate, TIME ATTRIBUTE(PROCTIME) proctime)], elementType=[class [Ljava.lang.Object;]) + +- LogicalTableFunctionScan(invocation=[org$apache$flink$table$functions$TemporalTableFunctionImpl$3f15d21732a0cca300e47b78077046df($2)], rowType=[RecordType(VARCHAR(2147483647) currency, INTEGER rate, TIME ATTRIBUTE(PROCTIME) proctime)], elementType=[class [Ljava.lang.Object;]) ]]> @@ -116,7 +116,7 @@ LogicalProject(rate=[AS(*($0, $4), _UTF-16LE'rate')]) +- LogicalFilter(condition=[=($3, $1)]) +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{}]) :- LogicalTableScan(table=[[default_catalog, default_database, Orders]]) - +- 
LogicalTableFunctionScan(invocation=[org$apache$flink$table$functions$TemporalTableFunctionImpl$9f8cd64a3b5a060e7794a65524cd070a($2)], rowType=[RecordType(VARCHAR(2147483647) currency, INTEGER rate, TIME ATTRIBUTE(ROWTIME) rowtime)], elementType=[class [Ljava.lang.Object;]) + +- LogicalTableFunctionScan(invocation=[org$apache$flink$table$functions$TemporalTableFunctionImpl$7b64fb5334c47de8df24f3ab20394c19($2)], rowType=[RecordType(VARCHAR(2147483647) currency, INTEGER rate, TIME ATTRIBUTE(ROWTIME) rowtime)], elementType=[class [Ljava.lang.Object;]) ]]> diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/data/conversion/RowRowConverter.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/data/conversion/RowRowConverter.java index f2af93b7ca2..794e6e9fdd4 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/data/conversion/RowRowConverter.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/data/conversion/RowRowConverter.java @@ -24,10 +24,15 @@ import org.apache.flink.table.data.RowData; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.logical.RowType; import org.apache.flink.types.Row; +import org.apache.flink.types.RowUtils; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Set; import java.util.stream.IntStream; +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getFieldNames; + /** Converter for {@link RowType} of {@link Row} external type. */ @Internal public class RowRowConverter implements DataStructureConverter { @@ -38,11 +43,15 @@ public class RowRowConverter implements DataStructureConverter { private final RowData.FieldGetter[] fieldGetters; + private final LinkedHashMap positionByName; + private RowRowConverter( DataStructureConverter[] fieldConverters, - RowData.FieldGetter[] fieldGetters) { + RowData.FieldGetter[] fieldGetters, + LinkedHashMap positionByName) { this.fieldConverters = fieldConverters; this.fieldGetters = fieldGetters; + this.positionByName = positionByName; } @Override @@ -56,22 +65,45 @@ public class RowRowConverter implements DataStructureConverter { public RowData toInternal(Row external) { final int length = fieldConverters.length; final GenericRowData genericRow = new GenericRowData(external.getKind(), length); - for (int pos = 0; pos < length; pos++) { - final Object value = external.getField(pos); - genericRow.setField(pos, fieldConverters[pos].toInternalOrNull(value)); + + final Set fieldNames = external.getFieldNames(false); + + // position-based field access + if (fieldNames == null) { + for (int pos = 0; pos < length; pos++) { + final Object value = external.getField(pos); + genericRow.setField(pos, fieldConverters[pos].toInternalOrNull(value)); + } + } + // name-based field access + else { + for (String fieldName : fieldNames) { + final Integer targetPos = positionByName.get(fieldName); + if (targetPos == null) { + throw new IllegalArgumentException( + String.format( + "Unknown field name '%s' for mapping to a row position. 
" + + "Available names are: %s", + fieldName, positionByName.keySet())); + } + final Object value = external.getField(fieldName); + genericRow.setField(targetPos, fieldConverters[targetPos].toInternalOrNull(value)); + } } + return genericRow; } @Override public Row toExternal(RowData internal) { final int length = fieldConverters.length; - final Row row = new Row(internal.getRowKind(), length); + final Object[] fieldByPosition = new Object[length]; for (int pos = 0; pos < length; pos++) { final Object value = fieldGetters[pos].getFieldOrNull(internal); - row.setField(pos, fieldConverters[pos].toExternalOrNull(value)); + fieldByPosition[pos] = fieldConverters[pos].toExternalOrNull(value); } - return row; + return RowUtils.createRowWithNamedPositions( + internal.getRowKind(), fieldByPosition, positionByName); } // -------------------------------------------------------------------------------------------- @@ -92,6 +124,11 @@ public class RowRowConverter implements DataStructureConverter { RowData.createFieldGetter( fields.get(pos).getLogicalType(), pos)) .toArray(RowData.FieldGetter[]::new); - return new RowRowConverter(fieldConverters, fieldGetters); + final List fieldNames = getFieldNames(dataType.getLogicalType()); + final LinkedHashMap positionByName = new LinkedHashMap<>(); + for (int i = 0; i < fieldNames.size(); i++) { + positionByName.put(fieldNames.get(i), i); + } + return new RowRowConverter(fieldConverters, fieldGetters, positionByName); } } diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/data/DataStructureConvertersTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/data/DataStructureConvertersTest.java index 7f34ab8fcf2..0780632282b 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/data/DataStructureConvertersTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/data/DataStructureConvertersTest.java @@ -237,6 +237,16 @@ public class DataStructureConvertersTest { FIELD("b_1", DOUBLE()), FIELD("b_2", BOOLEAN()))))) .convertedTo(Row.class, Row.ofKind(RowKind.DELETE, 12, Row.of(2.0, null))) + .convertedToSupplier( + Row.class, + () -> { + final Row namedRow = Row.withNames(RowKind.DELETE); + namedRow.setField("a", 12); + final Row sparseNamedRow = Row.withNames(); + sparseNamedRow.setField("b_1", 2.0); // "b_2" is omitted + namedRow.setField("b", sparseNamedRow); + return namedRow; + }) .convertedTo( RowData.class, GenericRowData.ofKind( -- GitLab