Commit 28bcbd0c authored by Dian Fu, committed by Hequn Cheng

[FLINK-16608][python] Support TimeType in vectorized Python UDF

Parent 9992b377
......@@ -439,6 +439,15 @@ class ArrowCoder(DeterministicCoder):
field.type.nullable)
elif field.type.type_name == flink_fn_execution_pb2.Schema.TypeName.DATE:
return pa.field(field.name, pa.date32(), field.type.nullable)
elif field.type.type_name == flink_fn_execution_pb2.Schema.TypeName.TIME:
if field.type.time_info.precision == 0:
return pa.field(field.name, pa.time32('s'), field.type.nullable)
elif 1 <= field.type.time_info.precision <= 3:
return pa.field(field.name, pa.time32('ms'), field.type.nullable)
elif 4 <= field.type.time_info.precision <= 6:
return pa.field(field.name, pa.time64('us'), field.type.nullable)
else:
return pa.field(field.name, pa.time64('ns'), field.type.nullable)
else:
raise ValueError("field_type %s is not supported." % field.type)
......
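For reference, the precision-to-Arrow-type mapping that the coder above implements can be sketched as a standalone function (a minimal sketch using pyarrow directly; the function name is illustrative and the protobuf plumbing is omitted):

```python
import pyarrow as pa

def time_precision_to_arrow_type(precision):
    """Map Flink TIME(precision), 0-9 fractional-second digits, to an Arrow time type."""
    if precision == 0:
        return pa.time32('s')    # second resolution fits in 32 bits
    elif 1 <= precision <= 3:
        return pa.time32('ms')   # millisecond resolution fits in 32 bits
    elif 4 <= precision <= 6:
        return pa.time64('us')   # microseconds need 64 bits
    else:
        return pa.time64('ns')   # nanoseconds need 64 bits

assert time_precision_to_arrow_type(3) == pa.time32('ms')
assert time_precision_to_arrow_type(9) == pa.time64('ns')
```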
......@@ -134,6 +134,12 @@ class PandasUDFITTests(object):
'date_param of wrong type %s !' % type(date_param[0])
return date_param
def time_func(time_param):
assert isinstance(time_param, pd.Series)
assert isinstance(time_param[0], datetime.time), \
'time_param of wrong type %s !' % type(time_param[0])
return time_param
self.t_env.register_function(
"tinyint_func",
udf(tinyint_func, [DataTypes.TINYINT()], DataTypes.TINYINT(), udf_type="pandas"))
......@@ -179,19 +185,23 @@ class PandasUDFITTests(object):
"date_func",
udf(date_func, [DataTypes.DATE()], DataTypes.DATE(), udf_type="pandas"))
self.t_env.register_function(
"time_func",
udf(time_func, [DataTypes.TIME()], DataTypes.TIME(), udf_type="pandas"))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n'],
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o'],
[DataTypes.TINYINT(), DataTypes.SMALLINT(), DataTypes.INT(), DataTypes.BIGINT(),
DataTypes.BOOLEAN(), DataTypes.BOOLEAN(), DataTypes.FLOAT(), DataTypes.DOUBLE(),
DataTypes.STRING(), DataTypes.STRING(), DataTypes.BYTES(), DataTypes.DECIMAL(38, 18),
DataTypes.DECIMAL(38, 18), DataTypes.DATE()])
DataTypes.DECIMAL(38, 18), DataTypes.DATE(), DataTypes.TIME()])
self.t_env.register_table_sink("Results", table_sink)
t = self.t_env.from_elements(
[(1, 32767, -2147483648, 1, True, False, 1.0, 1.0, 'hello', '中文',
bytearray(b'flink'), decimal.Decimal('1000000000000000000.05'),
decimal.Decimal('1000000000000000000.05999999999999999899999999999'),
datetime.date(2014, 9, 13))],
datetime.date(2014, 9, 13), datetime.time(hour=1, minute=0, second=1))],
DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.TINYINT()),
DataTypes.FIELD("b", DataTypes.SMALLINT()),
......@@ -206,7 +216,8 @@ class PandasUDFITTests(object):
DataTypes.FIELD("k", DataTypes.BYTES()),
DataTypes.FIELD("l", DataTypes.DECIMAL(38, 18)),
DataTypes.FIELD("m", DataTypes.DECIMAL(38, 18)),
DataTypes.FIELD("n", DataTypes.DATE())]))
DataTypes.FIELD("n", DataTypes.DATE()),
DataTypes.FIELD("o", DataTypes.TIME())]))
t.select("tinyint_func(a),"
"smallint_func(b),"
......@@ -221,14 +232,15 @@ class PandasUDFITTests(object):
"varbinary_func(k),"
"decimal_func(l),"
"decimal_func(m),"
"date_func(n)") \
"date_func(n),"
"time_func(o)") \
.insert_into("Results")
self.t_env.execute("test")
actual = source_sink_utils.results()
self.assert_equals(actual,
["1,32767,-2147483648,1,true,false,1.0,1.0,hello,中文,"
"[102, 108, 105, 110, 107],1000000000000000000.050000000000000000,"
"1000000000000000000.059999999999999999,2014-09-13"])
"1000000000000000000.059999999999999999,2014-09-13,01:00:01"])
class StreamPandasUDFITTests(PandasUDFITTests,
......
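End to end, the change means a vectorized Python UDF can now declare TIME arguments and results. A minimal sketch mirroring the registration pattern used in the test above (`t_env` is assumed to be an existing TableEnvironment):

```python
import datetime

from pyflink.table import DataTypes
from pyflink.table.udf import udf

def time_func(time_param):
    # time_param arrives as a pandas.Series of datetime.time values.
    return time_param

# t_env is assumed to be an existing TableEnvironment.
t_env.register_function(
    "time_func",
    udf(time_func, [DataTypes.TIME()], DataTypes.TIME(), udf_type="pandas"))
t = t_env.from_elements(
    [(datetime.time(hour=1, minute=0, second=1),)],
    DataTypes.ROW([DataTypes.FIELD("o", DataTypes.TIME())]))
t.select("time_func(o)")
```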
......@@ -31,6 +31,7 @@ import org.apache.flink.table.runtime.arrow.readers.FloatFieldReader;
import org.apache.flink.table.runtime.arrow.readers.IntFieldReader;
import org.apache.flink.table.runtime.arrow.readers.RowArrowReader;
import org.apache.flink.table.runtime.arrow.readers.SmallIntFieldReader;
import org.apache.flink.table.runtime.arrow.readers.TimeFieldReader;
import org.apache.flink.table.runtime.arrow.readers.TinyIntFieldReader;
import org.apache.flink.table.runtime.arrow.readers.VarBinaryFieldReader;
import org.apache.flink.table.runtime.arrow.readers.VarCharFieldReader;
......@@ -42,6 +43,7 @@ import org.apache.flink.table.runtime.arrow.vectors.ArrowDoubleColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowFloatColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowSmallIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowTimeColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowTinyIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowVarBinaryColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowVarCharColumnVector;
......@@ -55,6 +57,7 @@ import org.apache.flink.table.runtime.arrow.writers.BaseRowDoubleWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowFloatWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowIntWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowSmallIntWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowTimeWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowTinyIntWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowVarBinaryWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowVarCharWriter;
......@@ -66,6 +69,7 @@ import org.apache.flink.table.runtime.arrow.writers.DoubleWriter;
import org.apache.flink.table.runtime.arrow.writers.FloatWriter;
import org.apache.flink.table.runtime.arrow.writers.IntWriter;
import org.apache.flink.table.runtime.arrow.writers.SmallIntWriter;
import org.apache.flink.table.runtime.arrow.writers.TimeWriter;
import org.apache.flink.table.runtime.arrow.writers.TinyIntWriter;
import org.apache.flink.table.runtime.arrow.writers.VarBinaryWriter;
import org.apache.flink.table.runtime.arrow.writers.VarCharWriter;
......@@ -80,6 +84,7 @@ import org.apache.flink.table.types.logical.LegacyTypeInformationType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.SmallIntType;
import org.apache.flink.table.types.logical.TimeType;
import org.apache.flink.table.types.logical.TinyIntType;
import org.apache.flink.table.types.logical.VarBinaryType;
import org.apache.flink.table.types.logical.VarCharType;
......@@ -96,12 +101,17 @@ import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.SmallIntVector;
import org.apache.arrow.vector.TimeMicroVector;
import org.apache.arrow.vector.TimeMilliVector;
import org.apache.arrow.vector.TimeNanoVector;
import org.apache.arrow.vector.TimeSecVector;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
......@@ -178,6 +188,9 @@ public final class ArrowUtils {
return new DecimalWriter(decimalVector, getPrecision(decimalVector), decimalVector.getScale());
} else if (vector instanceof DateDayVector) {
return new DateWriter((DateDayVector) vector);
} else if (vector instanceof TimeSecVector || vector instanceof TimeMilliVector ||
vector instanceof TimeMicroVector || vector instanceof TimeNanoVector) {
return new TimeWriter(vector);
} else {
throw new UnsupportedOperationException(String.format(
"Unsupported type %s.", fieldType));
......@@ -223,6 +236,9 @@ public final class ArrowUtils {
return new BaseRowDecimalWriter(decimalVector, getPrecision(decimalVector), decimalVector.getScale());
} else if (vector instanceof DateDayVector) {
return new BaseRowDateWriter((DateDayVector) vector);
} else if (vector instanceof TimeSecVector || vector instanceof TimeMilliVector ||
vector instanceof TimeMicroVector || vector instanceof TimeNanoVector) {
return new BaseRowTimeWriter(vector);
} else {
throw new UnsupportedOperationException(String.format(
"Unsupported type %s.", fieldType));
......@@ -265,6 +281,9 @@ public final class ArrowUtils {
return new DecimalFieldReader((DecimalVector) vector);
} else if (vector instanceof DateDayVector) {
return new DateFieldReader((DateDayVector) vector);
} else if (vector instanceof TimeSecVector || vector instanceof TimeMilliVector ||
vector instanceof TimeMicroVector || vector instanceof TimeNanoVector) {
return new TimeFieldReader(vector);
} else {
throw new UnsupportedOperationException(String.format(
"Unsupported type %s.", fieldType));
......@@ -307,6 +326,9 @@ public final class ArrowUtils {
return new ArrowDecimalColumnVector((DecimalVector) vector);
} else if (vector instanceof DateDayVector) {
return new ArrowDateColumnVector((DateDayVector) vector);
} else if (vector instanceof TimeSecVector || vector instanceof TimeMilliVector ||
vector instanceof TimeMicroVector || vector instanceof TimeNanoVector) {
return new ArrowTimeColumnVector(vector);
} else {
throw new UnsupportedOperationException(String.format(
"Unsupported type %s.", fieldType));
......@@ -372,6 +394,19 @@ public final class ArrowUtils {
return new ArrowType.Date(DateUnit.DAY);
}
@Override
public ArrowType visit(TimeType timeType) {
if (timeType.getPrecision() == 0) {
return new ArrowType.Time(TimeUnit.SECOND, 32);
} else if (timeType.getPrecision() >= 1 && timeType.getPrecision() <= 3) {
return new ArrowType.Time(TimeUnit.MILLISECOND, 32);
} else if (timeType.getPrecision() >= 4 && timeType.getPrecision() <= 6) {
return new ArrowType.Time(TimeUnit.MICROSECOND, 64);
} else {
return new ArrowType.Time(TimeUnit.NANOSECOND, 64);
}
}
@Override
protected ArrowType defaultMethod(LogicalType logicalType) {
if (logicalType instanceof LegacyTypeInformationType) {
......
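The precision buckets in the visitor above mirror the Arrow format specification: Time32 vectors only support the SECOND and MILLISECOND units, and Time64 only MICROSECOND and NANOSECOND, so TIME(p) (where p is the number of fractional-second digits, 0 through 9) maps to a 32-bit vector for p <= 3 and to a 64-bit vector otherwise.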
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.arrow.readers;
import org.apache.flink.annotation.Internal;
import org.apache.flink.util.Preconditions;
import org.apache.arrow.vector.TimeMicroVector;
import org.apache.arrow.vector.TimeMilliVector;
import org.apache.arrow.vector.TimeNanoVector;
import org.apache.arrow.vector.TimeSecVector;
import org.apache.arrow.vector.ValueVector;
import java.sql.Time;
import java.util.TimeZone;
/**
* {@link ArrowFieldReader} for Time.
*/
@Internal
public final class TimeFieldReader extends ArrowFieldReader<Time> {
// The local time zone.
private static final TimeZone LOCAL_TZ = TimeZone.getDefault();
public TimeFieldReader(ValueVector valueVector) {
super(valueVector);
Preconditions.checkState(
valueVector instanceof TimeSecVector ||
valueVector instanceof TimeMilliVector ||
valueVector instanceof TimeMicroVector ||
valueVector instanceof TimeNanoVector);
}
@Override
public Time read(int index) {
ValueVector valueVector = getValueVector();
if (valueVector.isNull(index)) {
return null;
} else {
// Normalize the Arrow value to milliseconds of the day.
long timeMilli;
if (valueVector instanceof TimeSecVector) {
timeMilli = ((TimeSecVector) valueVector).get(index) * 1000L;
} else if (valueVector instanceof TimeMilliVector) {
timeMilli = ((TimeMilliVector) valueVector).get(index);
} else if (valueVector instanceof TimeMicroVector) {
timeMilli = ((TimeMicroVector) valueVector).get(index) / 1000;
} else {
timeMilli = ((TimeNanoVector) valueVector).get(index) / 1000000;
}
// Arrow time values are time-zone agnostic while java.sql.Time is rendered in the
// local time zone, so compensate with the local offset (the reverse of TimeWriter).
return new Time(timeMilli - LOCAL_TZ.getOffset(timeMilli));
}
}
}
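Note the time-zone compensation in read(): an Arrow time value is a time-zone-agnostic count since midnight, while java.sql.Time wraps an epoch timestamp that is rendered in the local time zone. Subtracting the local offset here, and adding it back in TimeWriter further down, lets the wall-clock value round-trip unchanged.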
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.arrow.vectors;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.dataformat.vector.IntColumnVector;
import org.apache.flink.util.Preconditions;
import org.apache.arrow.vector.TimeMicroVector;
import org.apache.arrow.vector.TimeMilliVector;
import org.apache.arrow.vector.TimeNanoVector;
import org.apache.arrow.vector.TimeSecVector;
import org.apache.arrow.vector.ValueVector;
/**
* Arrow column vector for Time.
*/
@Internal
public final class ArrowTimeColumnVector implements IntColumnVector {
/**
* Container which is used to store the sequence of time values of a column to read.
*/
private final ValueVector valueVector;
public ArrowTimeColumnVector(ValueVector valueVector) {
this.valueVector = Preconditions.checkNotNull(valueVector);
Preconditions.checkState(
valueVector instanceof TimeSecVector ||
valueVector instanceof TimeMilliVector ||
valueVector instanceof TimeMicroVector ||
valueVector instanceof TimeNanoVector);
}
@Override
public int getInt(int i) {
// Normalize to the milliseconds of the day, Flink's internal representation of TIME.
if (valueVector instanceof TimeSecVector) {
return ((TimeSecVector) valueVector).get(i) * 1000;
} else if (valueVector instanceof TimeMilliVector) {
return ((TimeMilliVector) valueVector).get(i);
} else if (valueVector instanceof TimeMicroVector) {
return (int) (((TimeMicroVector) valueVector).get(i) / 1000);
} else {
return (int) (((TimeNanoVector) valueVector).get(i) / 1000000);
}
}
@Override
public boolean isNullAt(int i) {
return valueVector.isNull(i);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.arrow.writers;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.dataformat.BaseRow;
import org.apache.flink.util.Preconditions;
import org.apache.arrow.vector.BaseFixedWidthVector;
import org.apache.arrow.vector.TimeMicroVector;
import org.apache.arrow.vector.TimeMilliVector;
import org.apache.arrow.vector.TimeNanoVector;
import org.apache.arrow.vector.TimeSecVector;
import org.apache.arrow.vector.ValueVector;
/**
* {@link ArrowFieldWriter} for Time.
*/
@Internal
public final class BaseRowTimeWriter extends ArrowFieldWriter<BaseRow> {
public BaseRowTimeWriter(ValueVector valueVector) {
super(valueVector);
Preconditions.checkState(
valueVector instanceof TimeSecVector ||
valueVector instanceof TimeMilliVector ||
valueVector instanceof TimeMicroVector ||
valueVector instanceof TimeNanoVector);
}
@Override
public void doWrite(BaseRow row, int ordinal) {
ValueVector valueVector = getValueVector();
// Flink's internal TIME value is an int holding the milliseconds of the day;
// convert it to the time unit of the target vector.
if (row.isNullAt(ordinal)) {
((BaseFixedWidthVector) valueVector).setNull(getCount());
} else if (valueVector instanceof TimeSecVector) {
((TimeSecVector) valueVector).setSafe(getCount(), row.getInt(ordinal) / 1000);
} else if (valueVector instanceof TimeMilliVector) {
((TimeMilliVector) valueVector).setSafe(getCount(), row.getInt(ordinal));
} else if (valueVector instanceof TimeMicroVector) {
((TimeMicroVector) valueVector).setSafe(getCount(), row.getInt(ordinal) * 1000L);
} else {
((TimeNanoVector) valueVector).setSafe(getCount(), row.getInt(ordinal) * 1000000L);
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.arrow.writers;
import org.apache.flink.annotation.Internal;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;
import org.apache.arrow.vector.BaseFixedWidthVector;
import org.apache.arrow.vector.TimeMicroVector;
import org.apache.arrow.vector.TimeMilliVector;
import org.apache.arrow.vector.TimeNanoVector;
import org.apache.arrow.vector.TimeSecVector;
import org.apache.arrow.vector.ValueVector;
import java.sql.Time;
import java.util.TimeZone;
/**
* {@link ArrowFieldWriter} for Time.
*/
@Internal
public final class TimeWriter extends ArrowFieldWriter<Row> {
// The local time zone.
private static final TimeZone LOCAL_TZ = TimeZone.getDefault();
private static final long MILLIS_PER_DAY = 86400000L; // = 24 * 60 * 60 * 1000
public TimeWriter(ValueVector valueVector) {
super(valueVector);
Preconditions.checkState(
valueVector instanceof TimeSecVector ||
valueVector instanceof TimeMilliVector ||
valueVector instanceof TimeMicroVector ||
valueVector instanceof TimeNanoVector);
}
@Override
public void doWrite(Row row, int ordinal) {
ValueVector valueVector = getValueVector();
if (row.getField(ordinal) == null) {
((BaseFixedWidthVector) valueVector).setNull(getCount());
} else {
Time time = (Time) row.getField(ordinal);
// java.sql.Time renders an epoch timestamp in the local time zone while Arrow time
// values are time-zone agnostic, so shift by the local offset and keep only the time of day.
long ts = time.getTime() + LOCAL_TZ.getOffset(time.getTime());
int timeMilli = (int) (ts % MILLIS_PER_DAY);
if (valueVector instanceof TimeSecVector) {
((TimeSecVector) valueVector).setSafe(getCount(), timeMilli / 1000);
} else if (valueVector instanceof TimeMilliVector) {
((TimeMilliVector) valueVector).setSafe(getCount(), timeMilli);
} else if (valueVector instanceof TimeMicroVector) {
((TimeMicroVector) valueVector).setSafe(getCount(), timeMilli * 1000L);
} else {
((TimeNanoVector) valueVector).setSafe(getCount(), timeMilli * 1000000L);
}
}
}
}
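Since both writers start from a millisecond-of-day value (java.sql.Time in TimeWriter, an int field in BaseRowTimeWriter), TIME columns with precision greater than 3 are zero-padded when written to microsecond or nanosecond vectors; sub-millisecond digits are not preserved end to end.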
......@@ -31,6 +31,7 @@ import org.apache.flink.table.runtime.arrow.readers.FloatFieldReader;
import org.apache.flink.table.runtime.arrow.readers.IntFieldReader;
import org.apache.flink.table.runtime.arrow.readers.RowArrowReader;
import org.apache.flink.table.runtime.arrow.readers.SmallIntFieldReader;
import org.apache.flink.table.runtime.arrow.readers.TimeFieldReader;
import org.apache.flink.table.runtime.arrow.readers.TinyIntFieldReader;
import org.apache.flink.table.runtime.arrow.readers.VarBinaryFieldReader;
import org.apache.flink.table.runtime.arrow.readers.VarCharFieldReader;
......@@ -42,6 +43,7 @@ import org.apache.flink.table.runtime.arrow.vectors.ArrowDoubleColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowFloatColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowSmallIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowTimeColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowTinyIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowVarBinaryColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowVarCharColumnVector;
......@@ -55,6 +57,7 @@ import org.apache.flink.table.runtime.arrow.writers.BaseRowDoubleWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowFloatWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowIntWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowSmallIntWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowTimeWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowTinyIntWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowVarBinaryWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowVarCharWriter;
......@@ -66,6 +69,7 @@ import org.apache.flink.table.runtime.arrow.writers.DoubleWriter;
import org.apache.flink.table.runtime.arrow.writers.FloatWriter;
import org.apache.flink.table.runtime.arrow.writers.IntWriter;
import org.apache.flink.table.runtime.arrow.writers.SmallIntWriter;
import org.apache.flink.table.runtime.arrow.writers.TimeWriter;
import org.apache.flink.table.runtime.arrow.writers.TinyIntWriter;
import org.apache.flink.table.runtime.arrow.writers.VarBinaryWriter;
import org.apache.flink.table.runtime.arrow.writers.VarCharWriter;
......@@ -79,6 +83,7 @@ import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.SmallIntType;
import org.apache.flink.table.types.logical.TimeType;
import org.apache.flink.table.types.logical.TinyIntType;
import org.apache.flink.table.types.logical.VarBinaryType;
import org.apache.flink.table.types.logical.VarCharType;
......@@ -88,6 +93,7 @@ import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
......@@ -134,6 +140,14 @@ public class ArrowUtilsTest {
DecimalWriter.class, BaseRowDecimalWriter.class, DecimalFieldReader.class, ArrowDecimalColumnVector.class));
testFields.add(Tuple7.of("f11", new DateType(), new ArrowType.Date(DateUnit.DAY),
DateWriter.class, BaseRowDateWriter.class, DateFieldReader.class, ArrowDateColumnVector.class));
testFields.add(Tuple7.of("f13", new TimeType(0), new ArrowType.Time(TimeUnit.SECOND, 32),
TimeWriter.class, BaseRowTimeWriter.class, TimeFieldReader.class, ArrowTimeColumnVector.class));
testFields.add(Tuple7.of("f14", new TimeType(2), new ArrowType.Time(TimeUnit.MILLISECOND, 32),
TimeWriter.class, BaseRowTimeWriter.class, TimeFieldReader.class, ArrowTimeColumnVector.class));
testFields.add(Tuple7.of("f15", new TimeType(4), new ArrowType.Time(TimeUnit.MICROSECOND, 64),
TimeWriter.class, BaseRowTimeWriter.class, TimeFieldReader.class, ArrowTimeColumnVector.class));
testFields.add(Tuple7.of("f16", new TimeType(8), new ArrowType.Time(TimeUnit.NANOSECOND, 64),
TimeWriter.class, BaseRowTimeWriter.class, TimeFieldReader.class, ArrowTimeColumnVector.class));
List<RowType.RowField> rowFields = new ArrayList<>();
for (Tuple7<String, LogicalType, ArrowType, Class<?>, Class<?>, Class<?>, Class<?>> field : testFields) {
......
......@@ -35,6 +35,7 @@ import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.SmallIntType;
import org.apache.flink.table.types.logical.TimeType;
import org.apache.flink.table.types.logical.TinyIntType;
import org.apache.flink.table.types.logical.VarBinaryType;
import org.apache.flink.table.types.logical.VarCharType;
......@@ -101,6 +102,10 @@ public class BaseRowArrowReaderWriterTest extends ArrowReaderWriterTestBase<Base
fieldTypes.add(new VarBinaryType());
fieldTypes.add(new DecimalType(10, 3));
fieldTypes.add(new DateType());
fieldTypes.add(new TimeType(0));
fieldTypes.add(new TimeType(2));
fieldTypes.add(new TimeType(4));
fieldTypes.add(new TimeType(8));
List<RowType.RowField> rowFields = new ArrayList<>();
for (int i = 0; i < fieldTypes.size(); i++) {
......@@ -128,12 +133,12 @@ public class BaseRowArrowReaderWriterTest extends ArrowReaderWriterTestBase<Base
@Override
public BaseRow[] getTestData() {
BaseRow row1 = StreamRecordUtils.baserow((byte) 1, (short) 2, 3, 4L, true, 1.0f, 1.0, "hello", "hello".getBytes(), Decimal.fromLong(1, 10, 3), 100);
BinaryRow row2 = StreamRecordUtils.binaryrow((byte) 1, (short) 2, 3, 4L, false, 1.0f, 1.0, "中文", "中文".getBytes(), Decimal.fromLong(1, 10, 3), 100);
BaseRow row3 = StreamRecordUtils.baserow(null, (short) 2, 3, 4L, false, 1.0f, 1.0, "中文", "中文".getBytes(), Decimal.fromLong(1, 10, 3), 100);
BinaryRow row4 = StreamRecordUtils.binaryrow((byte) 1, null, 3, 4L, true, 1.0f, 1.0, "hello", "hello".getBytes(), Decimal.fromLong(1, 10, 3), 100);
BaseRow row5 = StreamRecordUtils.baserow(null, null, null, null, null, null, null, null, null, null, null);
BinaryRow row6 = StreamRecordUtils.binaryrow(null, null, null, null, null, null, null, null, null, null, null);
BaseRow row1 = StreamRecordUtils.baserow((byte) 1, (short) 2, 3, 4L, true, 1.0f, 1.0, "hello", "hello".getBytes(), Decimal.fromLong(1, 10, 3), 100, 3600000, 3600000, 3600000, 3600000);
BinaryRow row2 = StreamRecordUtils.binaryrow((byte) 1, (short) 2, 3, 4L, false, 1.0f, 1.0, "中文", "中文".getBytes(), Decimal.fromLong(1, 10, 3), 100, 3600000, 3600000, 3600000, 3600000);
BaseRow row3 = StreamRecordUtils.baserow(null, (short) 2, 3, 4L, false, 1.0f, 1.0, "中文", "中文".getBytes(), Decimal.fromLong(1, 10, 3), 100, 3600000, 3600000, 3600000, 3600000);
BinaryRow row4 = StreamRecordUtils.binaryrow((byte) 1, null, 3, 4L, true, 1.0f, 1.0, "hello", "hello".getBytes(), Decimal.fromLong(1, 10, 3), 100, 3600000, 3600000, 3600000, 3600000);
BaseRow row5 = StreamRecordUtils.baserow(null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);
BinaryRow row6 = StreamRecordUtils.binaryrow(null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);
return new BaseRow[]{row1, row2, row3, row4, row5, row6};
}
}
......@@ -30,6 +30,7 @@ import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.SmallIntType;
import org.apache.flink.table.types.logical.TimeType;
import org.apache.flink.table.types.logical.TinyIntType;
import org.apache.flink.table.types.logical.VarBinaryType;
import org.apache.flink.table.types.logical.VarCharType;
......@@ -69,6 +70,10 @@ public class RowArrowReaderWriterTest extends ArrowReaderWriterTestBase<Row> {
fieldTypes.add(new VarBinaryType());
fieldTypes.add(new DecimalType(10, 0));
fieldTypes.add(new DateType());
fieldTypes.add(new TimeType(0));
fieldTypes.add(new TimeType(2));
fieldTypes.add(new TimeType(4));
fieldTypes.add(new TimeType(8));
List<RowType.RowField> rowFields = new ArrayList<>();
for (int i = 0; i < fieldTypes.size(); i++) {
......@@ -96,10 +101,13 @@ public class RowArrowReaderWriterTest extends ArrowReaderWriterTestBase<Row> {
@Override
public Row[] getTestData() {
Row row1 = Row.of((byte) 1, (short) 2, 3, 4L, true, 1.0f, 1.0, "hello", "hello".getBytes(), new BigDecimal(1), SqlDateTimeUtils.internalToDate(100));
Row row2 = Row.of(null, (short) 2, 3, 4L, false, 1.0f, 1.0, "中文", "中文".getBytes(), new BigDecimal(1), SqlDateTimeUtils.internalToDate(100));
Row row3 = Row.of((byte) 1, null, 3, 4L, true, 1.0f, 1.0, "hello", "hello".getBytes(), new BigDecimal(1), SqlDateTimeUtils.internalToDate(100));
Row row4 = Row.of(null, null, null, null, null, null, null, null, null, null, null);
Row row1 = Row.of((byte) 1, (short) 2, 3, 4L, true, 1.0f, 1.0, "hello", "hello".getBytes(), new BigDecimal(1), SqlDateTimeUtils.internalToDate(100),
SqlDateTimeUtils.internalToTime(3600000), SqlDateTimeUtils.internalToTime(3600000), SqlDateTimeUtils.internalToTime(3600000), SqlDateTimeUtils.internalToTime(3600000));
Row row2 = Row.of(null, (short) 2, 3, 4L, false, 1.0f, 1.0, "中文", "中文".getBytes(), new BigDecimal(1), SqlDateTimeUtils.internalToDate(100),
SqlDateTimeUtils.internalToTime(3600000), SqlDateTimeUtils.internalToTime(3600000), SqlDateTimeUtils.internalToTime(3600000), SqlDateTimeUtils.internalToTime(3600000));
Row row3 = Row.of((byte) 1, null, 3, 4L, true, 1.0f, 1.0, "hello", "hello".getBytes(), new BigDecimal(1), SqlDateTimeUtils.internalToDate(100),
SqlDateTimeUtils.internalToTime(3600000), SqlDateTimeUtils.internalToTime(3600000), SqlDateTimeUtils.internalToTime(3600000), SqlDateTimeUtils.internalToTime(3600000));
Row row4 = Row.of(null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);
return new Row[]{row1, row2, row3, row4};
}
}