Commit 2f9a28ae authored by T twalthr

[FLINK-3901] [table] Convert Java implementation to Scala and fix bugs

This closes #2283.
Parent c5d1d123
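
For orientation before the diff: the commit deletes the Java RowCsvInputFormat (shown first below) and replaces it with a Scala version in org.apache.flink.api.table.runtime.io. A minimal, hedged usage sketch of the new class; the file path and field types are hypothetical, and this assumes the Seq-based RowTypeInfo constructor used elsewhere in this commit:

import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.java.ExecutionEnvironment
import org.apache.flink.api.table.runtime.io.RowCsvInputFormat
import org.apache.flink.api.table.typeutils.RowTypeInfo
import org.apache.flink.core.fs.Path

// sketch only: read "name,age" lines into 2-field Rows with the default delimiters
val env = ExecutionEnvironment.getExecutionEnvironment
val rowType = new RowTypeInfo(
  Seq(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO))
val format = new RowCsvInputFormat(new Path("/tmp/people.csv"), rowType) // hypothetical path
val rows = env.createInput(format, rowType)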
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.io;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.io.ParseException;
import org.apache.flink.api.table.Row;
import org.apache.flink.api.table.typeutils.RowTypeInfo;
import org.apache.flink.core.fs.Path;
import org.apache.flink.types.parser.FieldParser;
import org.apache.flink.types.parser.FieldParser.ParseErrorState;

@Internal
public class RowCsvInputFormat extends CsvInputFormat<Row> {

    private static final long serialVersionUID = 1L;

    private int arity;

    public RowCsvInputFormat(Path filePath, RowTypeInfo rowTypeInfo) {
        this(filePath, DEFAULT_LINE_DELIMITER, DEFAULT_FIELD_DELIMITER, rowTypeInfo);
    }

    public RowCsvInputFormat(Path filePath, String lineDelimiter, String fieldDelimiter, RowTypeInfo rowTypeInfo) {
        this(filePath, lineDelimiter, fieldDelimiter, rowTypeInfo, createDefaultMask(rowTypeInfo.getArity()));
    }

    public RowCsvInputFormat(Path filePath, RowTypeInfo rowTypeInfo, int[] includedFieldsMask) {
        this(filePath, DEFAULT_LINE_DELIMITER, DEFAULT_FIELD_DELIMITER, rowTypeInfo, includedFieldsMask);
    }

    public RowCsvInputFormat(Path filePath, String lineDelimiter, String fieldDelimiter, RowTypeInfo rowTypeInfo,
            int[] includedFieldsMask) {
        this(filePath, lineDelimiter, fieldDelimiter, rowTypeInfo,
            (includedFieldsMask == null) ? createDefaultMask(rowTypeInfo.getArity()) : toBooleanMask(includedFieldsMask));
    }

    public RowCsvInputFormat(Path filePath, RowTypeInfo rowTypeInfo, boolean[] includedFieldsMask) {
        this(filePath, DEFAULT_LINE_DELIMITER, DEFAULT_FIELD_DELIMITER, rowTypeInfo, includedFieldsMask);
    }

    public RowCsvInputFormat(Path filePath, String lineDelimiter, String fieldDelimiter, RowTypeInfo rowTypeInfo,
            boolean[] includedFieldsMask) {
        super(filePath);

        if (rowTypeInfo.getArity() == 0) {
            throw new IllegalArgumentException("Row arity must be greater than 0.");
        }

        if (includedFieldsMask == null) {
            includedFieldsMask = createDefaultMask(rowTypeInfo.getArity());
        }

        this.arity = rowTypeInfo.getArity();

        setDelimiter(lineDelimiter);
        setFieldDelimiter(fieldDelimiter);

        Class<?>[] classes = new Class<?>[rowTypeInfo.getArity()];
        for (int i = 0; i < rowTypeInfo.getArity(); i++) {
            classes[i] = rowTypeInfo.getTypeAt(i).getTypeClass();
        }
        setFieldsGeneric(includedFieldsMask, classes);
    }

    @Override
    public Row fillRecord(Row reuse, Object[] parsedValues) {
        if (reuse == null) {
            reuse = new Row(arity);
        }
        for (int i = 0; i < parsedValues.length; i++) {
            reuse.setField(i, parsedValues[i]);
        }
        return reuse;
    }

    @Override
    protected boolean parseRecord(Object[] holders, byte[] bytes, int offset, int numBytes) throws ParseException {
        boolean[] fieldIncluded = this.fieldIncluded;

        int startPos = offset;
        final int limit = offset + numBytes;

        for (int field = 0, output = 0; field < fieldIncluded.length; field++) {
            // check valid start position
            if (startPos >= limit) {
                if (isLenient()) {
                    return false;
                } else {
                    throw new ParseException("Row too short: " + new String(bytes, offset, numBytes));
                }
            }

            if (fieldIncluded[field]) {
                // parse field
                @SuppressWarnings("unchecked")
                FieldParser<Object> parser = (FieldParser<Object>) this.getFieldParsers()[output];

                int latestValidPos = startPos;
                startPos = parser.resetErrorStateAndParse(bytes, startPos, limit, this.getFieldDelimiter(), holders[output]);

                if (!isLenient() && parser.getErrorState() != ParseErrorState.NONE) {
                    // Row is able to handle null values
                    if (parser.getErrorState() != ParseErrorState.EMPTY_STRING) {
                        throw new ParseException(
                            String.format("Parsing error for column %s of row '%s' originated by %s: %s.",
                                field, new String(bytes, offset, numBytes),
                                parser.getClass().getSimpleName(), parser.getErrorState()));
                    }
                }
                holders[output] = parser.getLastResult();

                // check parse result
                if (startPos < 0) {
                    holders[output] = null;
                    startPos = skipFields(bytes, latestValidPos, limit, this.getFieldDelimiter());
                }
                output++;
            } else {
                // skip field
                startPos = skipFields(bytes, startPos, limit, this.getFieldDelimiter());
            }
        }
        return true;
    }
}
@@ -19,14 +19,14 @@
 package org.apache.flink.api.table.plan.nodes.dataset

 import com.google.common.collect.ImmutableList
-import org.apache.calcite.plan.{RelTraitSet, RelOptCluster}
-import org.apache.calcite.rel.{RelWriter, RelNode}
+import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.{RelNode, RelWriter}
 import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.core.Values
 import org.apache.calcite.rex.RexLiteral
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.DataSet
-import org.apache.flink.api.table.runtime.ValuesInputFormat
+import org.apache.flink.api.table.runtime.io.ValuesInputFormat
 import org.apache.flink.api.table.typeutils.RowTypeInfo
 import org.apache.flink.api.table.typeutils.TypeConverter._
 import org.apache.flink.api.table.{BatchTableEnvironment, Row}
...
@@ -25,8 +25,8 @@ import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.core.Values
 import org.apache.calcite.rex.RexLiteral
 import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.table.runtime.io.ValuesInputFormat
 import org.apache.flink.api.table.{Row, StreamTableEnvironment}
-import org.apache.flink.api.table.runtime.ValuesInputFormat
 import org.apache.flink.api.table.typeutils.RowTypeInfo
 import org.apache.flink.api.table.typeutils.TypeConverter._
 import org.apache.flink.streaming.api.datastream.DataStream
...
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.table.runtime.io

import org.apache.flink.annotation.Internal
import org.apache.flink.api.common.io.ParseException
import org.apache.flink.api.java.io.CsvInputFormat
import org.apache.flink.api.java.io.CsvInputFormat.{DEFAULT_FIELD_DELIMITER, DEFAULT_LINE_DELIMITER, createDefaultMask, toBooleanMask}
import org.apache.flink.api.table.Row
import org.apache.flink.api.table.runtime.io.RowCsvInputFormat.extractTypeClasses
import org.apache.flink.api.table.typeutils.RowTypeInfo
import org.apache.flink.core.fs.Path
import org.apache.flink.types.parser.FieldParser
import org.apache.flink.types.parser.FieldParser.ParseErrorState

@Internal
@SerialVersionUID(1L)
class RowCsvInputFormat(
    filePath: Path,
    rowTypeInfo: RowTypeInfo,
    lineDelimiter: String = DEFAULT_LINE_DELIMITER,
    fieldDelimiter: String = DEFAULT_FIELD_DELIMITER,
    includedFieldsMask: Array[Boolean] = null)
  extends CsvInputFormat[Row](filePath) {

  if (rowTypeInfo.getArity == 0) {
    throw new IllegalArgumentException("Row arity must be greater than 0.")
  }

  private val arity = rowTypeInfo.getArity
  private lazy val defaultFieldMask = createDefaultMask(arity)
  private val fieldsMask = Option(includedFieldsMask).getOrElse(defaultFieldMask)

  // prepare CsvInputFormat
  setDelimiter(lineDelimiter)
  setFieldDelimiter(fieldDelimiter)
  setFieldsGeneric(fieldsMask, extractTypeClasses(rowTypeInfo))

  def this(
      filePath: Path,
      rowTypeInfo: RowTypeInfo,
      lineDelimiter: String,
      fieldDelimiter: String,
      includedFieldsMask: Array[Int]) {
    this(
      filePath,
      rowTypeInfo,
      lineDelimiter,
      fieldDelimiter,
      if (includedFieldsMask == null) {
        null
      } else {
        toBooleanMask(includedFieldsMask)
      })
  }

  def this(
      filePath: Path,
      rowTypeInfo: RowTypeInfo,
      includedFieldsMask: Array[Int]) {
    this(
      filePath,
      rowTypeInfo,
      DEFAULT_LINE_DELIMITER,
      DEFAULT_FIELD_DELIMITER,
      includedFieldsMask)
  }
  def fillRecord(reuse: Row, parsedValues: Array[AnyRef]): Row = {
    val reuseRow = if (reuse == null) {
      new Row(arity)
    } else {
      reuse
    }
    var i: Int = 0
    while (i < parsedValues.length) {
      // write into reuseRow, not reuse, which may be null here
      reuseRow.setField(i, parsedValues(i))
      i += 1
    }
    reuseRow
  }
  @throws[ParseException]
  override protected def parseRecord(
      holders: Array[AnyRef],
      bytes: Array[Byte],
      offset: Int,
      numBytes: Int)
    : Boolean = {

    val fieldDelimiter = this.getFieldDelimiter
    val fieldIncluded: Array[Boolean] = this.fieldIncluded

    var startPos = offset
    val limit = offset + numBytes

    var field = 0
    var output = 0
    while (field < fieldIncluded.length) {
      // check valid start position
      if (startPos >= limit) {
        if (isLenient) {
          return false
        } else {
          throw new ParseException("Row too short: " + new String(bytes, offset, numBytes))
        }
      }

      if (fieldIncluded(field)) {
        // parse field
        val parser: FieldParser[AnyRef] = this.getFieldParsers()(output)
          .asInstanceOf[FieldParser[AnyRef]]

        val latestValidPos = startPos
        startPos = parser.resetErrorStateAndParse(
          bytes,
          startPos,
          limit,
          fieldDelimiter,
          holders(output))

        if (!isLenient && (parser.getErrorState ne ParseErrorState.NONE)) {
          // Row is able to handle null values
          if (parser.getErrorState ne ParseErrorState.EMPTY_STRING) {
            throw new ParseException(s"Parsing error for column $field of row '"
              + new String(bytes, offset, numBytes)
              + s"' originated by ${parser.getClass.getSimpleName}: ${parser.getErrorState}.")
          }
        }
        holders(output) = parser.getLastResult

        // check parse result
        if (startPos < 0) {
          holders(output) = null
          startPos = skipFields(bytes, latestValidPos, limit, fieldDelimiter)
        }
        output += 1
      } else {
        // skip field
        startPos = skipFields(bytes, startPos, limit, fieldDelimiter)
      }

      // check if something went wrong
      if (startPos < 0) {
        throw new ParseException(s"Unexpected parser position for column $field of row '"
          + new String(bytes, offset, numBytes) + "'")
      }

      field += 1
    }
    true
  }
}

object RowCsvInputFormat {

  private def extractTypeClasses(rowTypeInfo: RowTypeInfo): Array[Class[_]] = {
    val classes = for (i <- 0 until rowTypeInfo.getArity)
      yield rowTypeInfo.getTypeAt(i).getTypeClass
    classes.toArray
  }
}
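
For illustration, a hedged sketch of field projection through the Array[Int] auxiliary constructor above; the path and types are hypothetical, and the Seq-based RowTypeInfo constructor is assumed as before. toBooleanMask(Array(0, 2)) marks columns 0 and 2 as included, so column 1 goes through the skipFields branch of parseRecord and the resulting Rows carry two fields:

import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.table.typeutils.RowTypeInfo
import org.apache.flink.core.fs.Path

// sketch only: two-field Rows filled from CSV columns 0 and 2; column 1 is skipped
val projectedType = new RowTypeInfo(
  Seq(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.DOUBLE_TYPE_INFO))
val format = new RowCsvInputFormat(
  new Path("/tmp/input.csv"), // hypothetical path
  projectedType,
  Array(0, 2)) // converted internally via toBooleanMask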
@@ -16,9 +16,9 @@
  * limitations under the License.
  */

-package org.apache.flink.api.table.runtime
+package org.apache.flink.api.table.runtime.io

-import org.apache.flink.api.common.io.{NonParallelInput, GenericInputFormat}
+import org.apache.flink.api.common.io.{GenericInputFormat, NonParallelInput}
 import org.apache.flink.api.table.Row

 class ValuesInputFormat(val rows: Seq[Row])
...
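
The class body is collapsed in this diff; the hunk only moves the package and reorders imports. From the visible signature and imports, a minimal sketch of what a Seq[Row]-backed, non-parallel GenericInputFormat looks like; the method bodies here are an assumption, not the committed code:

import org.apache.flink.api.common.io.{GenericInputFormat, NonParallelInput}
import org.apache.flink.api.table.Row
import org.apache.flink.core.io.GenericInputSplit

// sketch only: serve a fixed sequence of Rows from a single (non-parallel) split
class ValuesInputFormatSketch(val rows: Seq[Row])
  extends GenericInputFormat[Row]
  with NonParallelInput {

  private var iterator: Iterator[Row] = _

  override def open(split: GenericInputSplit): Unit = {
    super.open(split)
    iterator = rows.iterator
  }

  override def reachedEnd(): Boolean = !iterator.hasNext

  override def nextRecord(reuse: Row): Row = iterator.next()
}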
@@ -19,14 +19,12 @@
 package org.apache.flink.api.table.sources

 import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.java.io.TupleCsvInputFormat
-import org.apache.flink.api.java.tuple.Tuple
-import org.apache.flink.api.java.typeutils.{TupleTypeInfo, TupleTypeInfoBase}
+import org.apache.flink.api.java.io.CsvInputFormat
 import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
-import org.apache.flink.api.table.Row
-import org.apache.flink.core.fs.Path
+import org.apache.flink.api.table.{Row, TableException}
+import org.apache.flink.api.table.runtime.io.RowCsvInputFormat
 import org.apache.flink.api.table.typeutils.RowTypeInfo
-import org.apache.flink.api.java.io.RowCsvInputFormat
+import org.apache.flink.core.fs.Path

 /**
  * A [[TableSource]] for simple CSV files with a (logically) unlimited number of fields.
@@ -45,19 +43,23 @@ class CsvTableSource(
     path: String,
     fieldNames: Array[String],
     fieldTypes: Array[TypeInformation[_]],
-    fieldDelim: String = ",",
-    rowDelim: String = "\n",
+    fieldDelim: String = CsvInputFormat.DEFAULT_FIELD_DELIMITER,
+    rowDelim: String = CsvInputFormat.DEFAULT_LINE_DELIMITER,
     quoteCharacter: Character = null,
     ignoreFirstLine: Boolean = false,
     ignoreComments: String = null,
     lenient: Boolean = false)
   extends BatchTableSource[Row] {

+  if (fieldNames.length != fieldTypes.length) {
+    throw TableException("Number of field names and field types must be equal.")
+  }
+
+  private val returnType = new RowTypeInfo(fieldTypes)
+
   /** Returns the data of the table as a [[DataSet]] of [[Row]]. */
   override def getDataSet(execEnv: ExecutionEnvironment): DataSet[Row] = {
-    val typeInfo = getReturnType.asInstanceOf[RowTypeInfo]
-    val inputFormat = new RowCsvInputFormat(new Path(path), rowDelim, fieldDelim, typeInfo)
+    val inputFormat = new RowCsvInputFormat(new Path(path), returnType, rowDelim, fieldDelim)

     inputFormat.setSkipFirstLineAsHeader(ignoreFirstLine)
     inputFormat.setLenient(lenient)
@@ -68,7 +70,7 @@ class CsvTableSource(
       inputFormat.setCommentPrefix(ignoreComments)
     }

-    execEnv.createInput(inputFormat, typeInfo)
+    execEnv.createInput(inputFormat, returnType)
   }

   /** Returns the types of the table fields. */
@@ -81,7 +83,5 @@ class CsvTableSource(
   override def getNumberOfFields: Int = fieldNames.length

   /** Returns the [[RowTypeInfo]] for the return type of the [[CsvTableSource]]. */
-  override def getReturnType: RowTypeInfo = {
-    new RowTypeInfo(fieldTypes)
-  }
+  override def getReturnType: RowTypeInfo = returnType
 }
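
The net effect of these hunks: CsvTableSource now validates its arguments eagerly and computes returnType once instead of rebuilding it on every call. A hedged construction sketch; the file path, field names, and types are hypothetical:

import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.table.sources.CsvTableSource

// sketch only: a two-column CSV source that skips its header line
val source = new CsvTableSource(
  "/tmp/persons.csv",
  Array("name", "age"),
  Array[TypeInformation[_]](
    BasicTypeInfo.STRING_TYPE_INFO,
    BasicTypeInfo.INT_TYPE_INFO),
  ignoreFirstLine = true)

// mismatched name/type arrays now fail fast at construction time with
// "Number of field names and field types must be equal."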