提交 cf1d8c63 编写于 作者: D Dian Fu 提交者: Hequn Cheng

[FLINK-16608][python] Support VarCharType in vectorized Python UDF

上级 180c2c9b
......@@ -428,6 +428,8 @@ class ArrowCoder(DeterministicCoder):
return pa.field(field.name, pa.float32(), field.type.nullable)
elif field.type.type_name == flink_fn_execution_pb2.Schema.TypeName.DOUBLE:
return pa.field(field.name, pa.float64(), field.type.nullable)
elif field.type.type_name == flink_fn_execution_pb2.Schema.TypeName.VARCHAR:
return pa.field(field.name, pa.utf8(), field.type.nullable)
else:
raise ValueError("field_type %s is not supported." % field.type)
......
......@@ -109,6 +109,12 @@ class PandasUDFITTests(object):
'double_param of wrong type %s !' % type(double_param[0])
return double_param
def varchar_func(varchar_param):
assert isinstance(varchar_param, pd.Series)
assert isinstance(varchar_param[0], str), \
'varchar_param of wrong type %s !' % type(varchar_param[0])
return varchar_param
self.t_env.register_function(
"tinyint_func",
udf(tinyint_func, [DataTypes.TINYINT()], DataTypes.TINYINT(), udf_type="pandas"))
......@@ -137,14 +143,19 @@ class PandasUDFITTests(object):
"double_func",
udf(double_func, [DataTypes.DOUBLE()], DataTypes.DOUBLE(), udf_type="pandas"))
self.t_env.register_function(
"varchar_func",
udf(varchar_func, [DataTypes.STRING()], DataTypes.STRING(), udf_type="pandas"))
table_sink = source_sink_utils.TestAppendSink(
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'],
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'],
[DataTypes.TINYINT(), DataTypes.SMALLINT(), DataTypes.INT(), DataTypes.BIGINT(),
DataTypes.BOOLEAN(), DataTypes.BOOLEAN(), DataTypes.FLOAT(), DataTypes.DOUBLE()])
DataTypes.BOOLEAN(), DataTypes.BOOLEAN(), DataTypes.FLOAT(), DataTypes.DOUBLE(),
DataTypes.STRING(), DataTypes.STRING()])
self.t_env.register_table_sink("Results", table_sink)
t = self.t_env.from_elements(
[(1, 32767, -2147483648, 1, True, False, 1.0, 1.0)],
[(1, 32767, -2147483648, 1, True, False, 1.0, 1.0, 'hello', '中文')],
DataTypes.ROW(
[DataTypes.FIELD("a", DataTypes.TINYINT()),
DataTypes.FIELD("b", DataTypes.SMALLINT()),
......@@ -153,7 +164,9 @@ class PandasUDFITTests(object):
DataTypes.FIELD("e", DataTypes.BOOLEAN()),
DataTypes.FIELD("f", DataTypes.BOOLEAN()),
DataTypes.FIELD("g", DataTypes.FLOAT()),
DataTypes.FIELD("h", DataTypes.DOUBLE())]))
DataTypes.FIELD("h", DataTypes.DOUBLE()),
DataTypes.FIELD("i", DataTypes.STRING()),
DataTypes.FIELD("j", DataTypes.STRING())]))
t.select("tinyint_func(a),"
"smallint_func(b),"
......@@ -162,12 +175,14 @@ class PandasUDFITTests(object):
"boolean_func(e),"
"boolean_func(f),"
"float_func(g),"
"double_func(h)") \
"double_func(h),"
"varchar_func(i),"
"varchar_func(j)") \
.insert_into("Results")
self.t_env.execute("test")
actual = source_sink_utils.results()
self.assert_equals(actual,
["1,32767,-2147483648,1,true,false,1.0,1.0"])
["1,32767,-2147483648,1,true,false,1.0,1.0,hello,中文"])
class StreamPandasUDFITTests(PandasUDFITTests,
......
......@@ -30,6 +30,7 @@ import org.apache.flink.table.runtime.arrow.readers.IntFieldReader;
import org.apache.flink.table.runtime.arrow.readers.RowArrowReader;
import org.apache.flink.table.runtime.arrow.readers.SmallIntFieldReader;
import org.apache.flink.table.runtime.arrow.readers.TinyIntFieldReader;
import org.apache.flink.table.runtime.arrow.readers.VarCharFieldReader;
import org.apache.flink.table.runtime.arrow.vectors.ArrowBigIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowBooleanColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowDoubleColumnVector;
......@@ -37,6 +38,7 @@ import org.apache.flink.table.runtime.arrow.vectors.ArrowFloatColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowSmallIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowTinyIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowVarCharColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.BaseRowArrowReader;
import org.apache.flink.table.runtime.arrow.writers.ArrowFieldWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowBigIntWriter;
......@@ -46,6 +48,7 @@ import org.apache.flink.table.runtime.arrow.writers.BaseRowFloatWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowIntWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowSmallIntWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowTinyIntWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowVarCharWriter;
import org.apache.flink.table.runtime.arrow.writers.BigIntWriter;
import org.apache.flink.table.runtime.arrow.writers.BooleanWriter;
import org.apache.flink.table.runtime.arrow.writers.DoubleWriter;
......@@ -53,6 +56,7 @@ import org.apache.flink.table.runtime.arrow.writers.FloatWriter;
import org.apache.flink.table.runtime.arrow.writers.IntWriter;
import org.apache.flink.table.runtime.arrow.writers.SmallIntWriter;
import org.apache.flink.table.runtime.arrow.writers.TinyIntWriter;
import org.apache.flink.table.runtime.arrow.writers.VarCharWriter;
import org.apache.flink.table.types.logical.BigIntType;
import org.apache.flink.table.types.logical.BooleanType;
import org.apache.flink.table.types.logical.DoubleType;
......@@ -62,6 +66,7 @@ import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.SmallIntType;
import org.apache.flink.table.types.logical.TinyIntType;
import org.apache.flink.table.types.logical.VarCharType;
import org.apache.flink.table.types.logical.utils.LogicalTypeDefaultVisitor;
import org.apache.flink.types.Row;
......@@ -74,6 +79,7 @@ import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.SmallIntVector;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.pojo.ArrowType;
......@@ -142,6 +148,8 @@ public final class ArrowUtils {
return new FloatWriter((Float4Vector) vector);
} else if (vector instanceof Float8Vector) {
return new DoubleWriter((Float8Vector) vector);
} else if (vector instanceof VarCharVector) {
return new VarCharWriter((VarCharVector) vector);
} else {
throw new UnsupportedOperationException(String.format(
"Unsupported type %s.", fieldType));
......@@ -178,6 +186,8 @@ public final class ArrowUtils {
return new BaseRowFloatWriter((Float4Vector) vector);
} else if (vector instanceof Float8Vector) {
return new BaseRowDoubleWriter((Float8Vector) vector);
} else if (vector instanceof VarCharVector) {
return new BaseRowVarCharWriter((VarCharVector) vector);
} else {
throw new UnsupportedOperationException(String.format(
"Unsupported type %s.", fieldType));
......@@ -212,6 +222,8 @@ public final class ArrowUtils {
return new FloatFieldReader((Float4Vector) vector);
} else if (vector instanceof Float8Vector) {
return new DoubleFieldReader((Float8Vector) vector);
} else if (vector instanceof VarCharVector) {
return new VarCharFieldReader((VarCharVector) vector);
} else {
throw new UnsupportedOperationException(String.format(
"Unsupported type %s.", fieldType));
......@@ -246,6 +258,8 @@ public final class ArrowUtils {
return new ArrowFloatColumnVector((Float4Vector) vector);
} else if (vector instanceof Float8Vector) {
return new ArrowDoubleColumnVector((Float8Vector) vector);
} else if (vector instanceof VarCharVector) {
return new ArrowVarCharColumnVector((VarCharVector) vector);
} else {
throw new UnsupportedOperationException(String.format(
"Unsupported type %s.", fieldType));
......@@ -291,6 +305,11 @@ public final class ArrowUtils {
return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
}
@Override
public ArrowType visit(VarCharType varCharType) {
return ArrowType.Utf8.INSTANCE;
}
@Override
protected ArrowType defaultMethod(LogicalType logicalType) {
throw new UnsupportedOperationException(String.format(
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.arrow.readers;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.runtime.util.StringUtf8Utils;
import org.apache.arrow.vector.VarCharVector;
/**
* {@link ArrowFieldReader} for VarChar.
*/
@Internal
public final class VarCharFieldReader extends ArrowFieldReader<String> {
public VarCharFieldReader(VarCharVector varCharVector) {
super(varCharVector);
}
@Override
public String read(int index) {
if (getValueVector().isNull(index)) {
return null;
} else {
byte[] bytes = ((VarCharVector) getValueVector()).get(index);
return StringUtf8Utils.decodeUTF8(bytes, 0, bytes.length);
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.arrow.vectors;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.dataformat.vector.BytesColumnVector;
import org.apache.flink.util.Preconditions;
import org.apache.arrow.vector.VarCharVector;
/**
* Arrow column vector for VarChar.
*/
@Internal
public final class ArrowVarCharColumnVector implements BytesColumnVector {
/**
* Container which is used to store the sequence of varchar values of a column to read.
*/
private final VarCharVector varCharVector;
public ArrowVarCharColumnVector(VarCharVector varCharVector) {
this.varCharVector = Preconditions.checkNotNull(varCharVector);
}
@Override
public Bytes getBytes(int i) {
byte[] bytes = varCharVector.get(i);
return new Bytes(bytes, 0, bytes.length);
}
@Override
public boolean isNullAt(int i) {
return varCharVector.isNull(i);
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.arrow.writers;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.dataformat.BaseRow;
import org.apache.arrow.vector.VarCharVector;
/**
* {@link ArrowFieldWriter} for VarChar.
*/
@Internal
public final class BaseRowVarCharWriter extends ArrowFieldWriter<BaseRow> {
public BaseRowVarCharWriter(VarCharVector varCharVector) {
super(varCharVector);
}
@Override
public void doWrite(BaseRow row, int ordinal) {
if (row.isNullAt(ordinal)) {
((VarCharVector) getValueVector()).setNull(getCount());
} else {
((VarCharVector) getValueVector()).setSafe(getCount(), row.getString(ordinal).getBytes());
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.arrow.writers;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.runtime.util.StringUtf8Utils;
import org.apache.flink.types.Row;
import org.apache.arrow.vector.VarCharVector;
/**
* {@link ArrowFieldWriter} for VarChar.
*/
@Internal
public final class VarCharWriter extends ArrowFieldWriter<Row> {
public VarCharWriter(VarCharVector varCharVector) {
super(varCharVector);
}
@Override
public void doWrite(Row value, int ordinal) {
if (value.getField(ordinal) == null) {
((VarCharVector) getValueVector()).setNull(getCount());
} else {
((VarCharVector) getValueVector()).setSafe(
getCount(), StringUtf8Utils.encodeUTF8(((String) value.getField(ordinal))));
}
}
}
......@@ -30,6 +30,7 @@ import org.apache.flink.table.runtime.arrow.readers.IntFieldReader;
import org.apache.flink.table.runtime.arrow.readers.RowArrowReader;
import org.apache.flink.table.runtime.arrow.readers.SmallIntFieldReader;
import org.apache.flink.table.runtime.arrow.readers.TinyIntFieldReader;
import org.apache.flink.table.runtime.arrow.readers.VarCharFieldReader;
import org.apache.flink.table.runtime.arrow.vectors.ArrowBigIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowBooleanColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowDoubleColumnVector;
......@@ -37,6 +38,7 @@ import org.apache.flink.table.runtime.arrow.vectors.ArrowFloatColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowSmallIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowTinyIntColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.ArrowVarCharColumnVector;
import org.apache.flink.table.runtime.arrow.vectors.BaseRowArrowReader;
import org.apache.flink.table.runtime.arrow.writers.ArrowFieldWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowBigIntWriter;
......@@ -46,6 +48,7 @@ import org.apache.flink.table.runtime.arrow.writers.BaseRowFloatWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowIntWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowSmallIntWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowTinyIntWriter;
import org.apache.flink.table.runtime.arrow.writers.BaseRowVarCharWriter;
import org.apache.flink.table.runtime.arrow.writers.BigIntWriter;
import org.apache.flink.table.runtime.arrow.writers.BooleanWriter;
import org.apache.flink.table.runtime.arrow.writers.DoubleWriter;
......@@ -53,6 +56,7 @@ import org.apache.flink.table.runtime.arrow.writers.FloatWriter;
import org.apache.flink.table.runtime.arrow.writers.IntWriter;
import org.apache.flink.table.runtime.arrow.writers.SmallIntWriter;
import org.apache.flink.table.runtime.arrow.writers.TinyIntWriter;
import org.apache.flink.table.runtime.arrow.writers.VarCharWriter;
import org.apache.flink.table.types.logical.BigIntType;
import org.apache.flink.table.types.logical.BooleanType;
import org.apache.flink.table.types.logical.DoubleType;
......@@ -62,6 +66,7 @@ import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.SmallIntType;
import org.apache.flink.table.types.logical.TinyIntType;
import org.apache.flink.table.types.logical.VarCharType;
import org.apache.flink.types.Row;
import org.apache.arrow.memory.BufferAllocator;
......@@ -105,6 +110,8 @@ public class ArrowUtilsTest {
FloatWriter.class, BaseRowFloatWriter.class, FloatFieldReader.class, ArrowFloatColumnVector.class));
testFields.add(Tuple7.of("f7", new DoubleType(), new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE),
DoubleWriter.class, BaseRowDoubleWriter.class, DoubleFieldReader.class, ArrowDoubleColumnVector.class));
testFields.add(Tuple7.of("f8", new VarCharType(), ArrowType.Utf8.INSTANCE,
VarCharWriter.class, BaseRowVarCharWriter.class, VarCharFieldReader.class, ArrowVarCharColumnVector.class));
List<RowType.RowField> rowFields = new ArrayList<>();
for (Tuple7<String, LogicalType, ArrowType, Class<?>, Class<?>, Class<?>, Class<?>> field : testFields) {
......
......@@ -33,6 +33,7 @@ import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.SmallIntType;
import org.apache.flink.table.types.logical.TinyIntType;
import org.apache.flink.table.types.logical.VarCharType;
import org.apache.flink.testutils.DeeplyEqualsChecker;
import org.apache.arrow.memory.BufferAllocator;
......@@ -92,6 +93,7 @@ public class BaseRowArrowReaderWriterTest extends ArrowReaderWriterTestBase<Base
fieldTypes.add(new BooleanType());
fieldTypes.add(new FloatType());
fieldTypes.add(new DoubleType());
fieldTypes.add(new VarCharType());
List<RowType.RowField> rowFields = new ArrayList<>();
for (int i = 0; i < fieldTypes.size(); i++) {
......@@ -119,12 +121,12 @@ public class BaseRowArrowReaderWriterTest extends ArrowReaderWriterTestBase<Base
@Override
public BaseRow[] getTestData() {
BaseRow row1 = StreamRecordUtils.baserow((byte) 1, (short) 2, 3, 4L, true, 1.0f, 1.0);
BinaryRow row2 = StreamRecordUtils.binaryrow((byte) 1, (short) 2, 3, 4L, false, 1.0f, 1.0);
BaseRow row3 = StreamRecordUtils.baserow(null, (short) 2, 3, 4L, false, 1.0f, 1.0);
BinaryRow row4 = StreamRecordUtils.binaryrow((byte) 1, null, 3, 4L, true, 1.0f, 1.0);
BaseRow row5 = StreamRecordUtils.baserow(null, null, null, null, null, null, null);
BinaryRow row6 = StreamRecordUtils.binaryrow(null, null, null, null, null, null, null);
BaseRow row1 = StreamRecordUtils.baserow((byte) 1, (short) 2, 3, 4L, true, 1.0f, 1.0, "hello");
BinaryRow row2 = StreamRecordUtils.binaryrow((byte) 1, (short) 2, 3, 4L, false, 1.0f, 1.0, "中文");
BaseRow row3 = StreamRecordUtils.baserow(null, (short) 2, 3, 4L, false, 1.0f, 1.0, "中文");
BinaryRow row4 = StreamRecordUtils.binaryrow((byte) 1, null, 3, 4L, true, 1.0f, 1.0, "hello");
BaseRow row5 = StreamRecordUtils.baserow(null, null, null, null, null, null, null, null);
BinaryRow row6 = StreamRecordUtils.binaryrow(null, null, null, null, null, null, null, null);
return new BaseRow[]{row1, row2, row3, row4, row5, row6};
}
}
......@@ -28,6 +28,7 @@ import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.SmallIntType;
import org.apache.flink.table.types.logical.TinyIntType;
import org.apache.flink.table.types.logical.VarCharType;
import org.apache.flink.types.Row;
import org.apache.arrow.memory.BufferAllocator;
......@@ -59,6 +60,7 @@ public class RowArrowReaderWriterTest extends ArrowReaderWriterTestBase<Row> {
fieldTypes.add(new BooleanType());
fieldTypes.add(new FloatType());
fieldTypes.add(new DoubleType());
fieldTypes.add(new VarCharType());
List<RowType.RowField> rowFields = new ArrayList<>();
for (int i = 0; i < fieldTypes.size(); i++) {
......@@ -86,10 +88,10 @@ public class RowArrowReaderWriterTest extends ArrowReaderWriterTestBase<Row> {
@Override
public Row[] getTestData() {
Row row1 = Row.of((byte) 1, (short) 2, 3, 4L, true, 1.0f, 1.0);
Row row2 = Row.of(null, (short) 2, 3, 4L, false, 1.0f, 1.0);
Row row3 = Row.of((byte) 1, null, 3, 4L, true, 1.0f, 1.0);
Row row4 = Row.of(null, null, null, null, null, null, null);
Row row1 = Row.of((byte) 1, (short) 2, 3, 4L, true, 1.0f, 1.0, "hello");
Row row2 = Row.of(null, (short) 2, 3, 4L, false, 1.0f, 1.0, "中文");
Row row3 = Row.of((byte) 1, null, 3, 4L, true, 1.0f, 1.0, "hello");
Row row4 = Row.of(null, null, null, null, null, null, null, null);
return new Row[]{row1, row2, row3, row4};
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册