提交 4dbb10f5 编写于 作者: T twalthr 提交者: Fabian Hueske

[FLINK-3108] [java] JoinOperator's with() calls the wrong TypeExtractor method

This closes #1440
上级 9547f08a
......@@ -647,7 +647,8 @@ public class CoGroupOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1, I2, OU
if (function == null) {
throw new NullPointerException("CoGroup function must not be null.");
}
TypeInformation<R> returnType = TypeExtractor.getCoGroupReturnTypes(function, input1.getType(), input2.getType());
TypeInformation<R> returnType = TypeExtractor.getCoGroupReturnTypes(function, input1.getType(), input2.getType(),
Utils.getCallLocationName(), true);
return new CoGroupOperator<I1, I2, R>(input1, input2, keys1, keys2, input1.clean(function), returnType,
groupSortKeyOrderFirst, groupSortKeyOrderSecond,
......
......@@ -559,16 +559,16 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
if (function == null) {
throw new NullPointerException("Join function must not be null.");
}
TypeInformation<R> returnType = TypeExtractor.getFlatJoinReturnTypes(function, getInput1Type(), getInput2Type());
TypeInformation<R> returnType = TypeExtractor.getFlatJoinReturnTypes(function, getInput1Type(), getInput2Type(), Utils.getCallLocationName(), true);
return new EquiJoin<>(getInput1(), getInput2(), getKeys1(), getKeys2(), clean(function), returnType, getJoinHint(), Utils.getCallLocationName(), joinType);
}
public <R> EquiJoin<I1, I2, R> with (JoinFunction<I1, I2, R> function) {
public <R> EquiJoin<I1, I2, R> with(JoinFunction<I1, I2, R> function) {
if (function == null) {
throw new NullPointerException("Join function must not be null.");
}
FlatJoinFunction<I1, I2, R> generatedFunction = new WrappingFlatJoinFunction<>(clean(function));
TypeInformation<R> returnType = TypeExtractor.getJoinReturnTypes(function, getInput1Type(), getInput2Type());
TypeInformation<R> returnType = TypeExtractor.getJoinReturnTypes(function, getInput1Type(), getInput2Type(), Utils.getCallLocationName(), true);
return new EquiJoin<>(getInput1(), getInput2(), getKeys1(), getKeys2(), generatedFunction, function, returnType, getJoinHint(), Utils.getCallLocationName(), joinType);
}
......@@ -582,7 +582,7 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
@Override
public void join(IN1 left, IN2 right, Collector<OUT> out) throws Exception {
out.collect (this.wrappedFunction.join(left, right));
out.collect(this.wrappedFunction.join(left, right));
}
}
......
......@@ -162,8 +162,8 @@ public class SortedGrouping<T> extends Grouping<T> {
throw new NullPointerException("GroupReduce function must not be null.");
}
TypeInformation<R> resultType = TypeExtractor.getGroupReduceReturnTypes(reducer,
this.getDataSet().getType());
return new GroupReduceOperator<T, R>(this, resultType, dataSet.clean(reducer), Utils.getCallLocationName() );
this.getDataSet().getType(), Utils.getCallLocationName(), true);
return new GroupReduceOperator<T, R>(this, resultType, dataSet.clean(reducer), Utils.getCallLocationName());
}
/**
......@@ -182,7 +182,8 @@ public class SortedGrouping<T> extends Grouping<T> {
if (combiner == null) {
throw new NullPointerException("GroupCombine function must not be null.");
}
TypeInformation<R> resultType = TypeExtractor.getGroupCombineReturnTypes(combiner, this.getDataSet().getType());
TypeInformation<R> resultType = TypeExtractor.getGroupCombineReturnTypes(combiner,
this.getDataSet().getType(), Utils.getCallLocationName(), true);
return new GroupCombineOperator<T, R>(this, resultType, dataSet.clean(combiner), Utils.getCallLocationName());
}
......
......@@ -156,7 +156,8 @@ public class UnsortedGrouping<T> extends Grouping<T> {
if (reducer == null) {
throw new NullPointerException("GroupReduce function must not be null.");
}
TypeInformation<R> resultType = TypeExtractor.getGroupReduceReturnTypes(reducer, this.getDataSet().getType());
TypeInformation<R> resultType = TypeExtractor.getGroupReduceReturnTypes(reducer,
this.getDataSet().getType(), Utils.getCallLocationName(), true);
return new GroupReduceOperator<T, R>(this, resultType, dataSet.clean(reducer), Utils.getCallLocationName());
}
......@@ -177,7 +178,8 @@ public class UnsortedGrouping<T> extends Grouping<T> {
if (combiner == null) {
throw new NullPointerException("GroupCombine function must not be null.");
}
TypeInformation<R> resultType = TypeExtractor.getGroupCombineReturnTypes(combiner, this.getDataSet().getType());
TypeInformation<R> resultType = TypeExtractor.getGroupCombineReturnTypes(combiner,
this.getDataSet().getType(), Utils.getCallLocationName(), true);
return new GroupCombineOperator<T, R>(this, resultType, dataSet.clean(combiner), Utils.getCallLocationName());
}
......
......@@ -24,8 +24,14 @@ import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
......@@ -42,7 +48,7 @@ import org.junit.runners.Parameterized.Parameters;
@RunWith(Parameterized.class)
public class TypeHintITCase extends JavaProgramTestBase {
private static int NUM_PROGRAMS = 3;
private static int NUM_PROGRAMS = 9;
private int curProgId = config.getInteger("ProgramId", -1);
......@@ -114,9 +120,9 @@ public class TypeHintITCase extends JavaProgramTestBase {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Integer> identityMapDs = ds.
flatMap(new FlatMapper<Tuple3<Integer, Long, String>, Integer>())
.returns(Integer.class);
DataSet<Integer> identityMapDs = ds
.flatMap(new FlatMapper<Tuple3<Integer, Long, String>, Integer>())
.returns(Integer.class);
List<Integer> result = identityMapDs.collect();
String expectedResult = "2\n" +
......@@ -126,6 +132,124 @@ public class TypeHintITCase extends JavaProgramTestBase {
compareResultAsText(result, expectedResult);
break;
}
// Test join with type information type hint
case 4: {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Tuple3<Integer, Long, String>> ds2 = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Integer> resultDs = ds1
.join(ds2)
.where(0)
.equalTo(0)
.with(new Joiner<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>, Integer>())
.returns(BasicTypeInfo.INT_TYPE_INFO);
List<Integer> result = resultDs.collect();
String expectedResult = "2\n" +
"3\n" +
"1\n";
compareResultAsText(result, expectedResult);
break;
}
// Test flat join with type information type hint
case 5: {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Tuple3<Integer, Long, String>> ds2 = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Integer> resultDs = ds1
.join(ds2)
.where(0)
.equalTo(0)
.with(new FlatJoiner<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>, Integer>())
.returns(BasicTypeInfo.INT_TYPE_INFO);
List<Integer> result = resultDs.collect();
String expectedResult = "2\n" +
"3\n" +
"1\n";
compareResultAsText(result, expectedResult);
break;
}
// Test unsorted group reduce with type information type hint
case 6: {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Integer> resultDs = ds
.groupBy(0)
.reduceGroup(new GroupReducer<Tuple3<Integer, Long, String>, Integer>())
.returns(BasicTypeInfo.INT_TYPE_INFO);
List<Integer> result = resultDs.collect();
String expectedResult = "2\n" +
"3\n" +
"1\n";
compareResultAsText(result, expectedResult);
break;
}
// Test sorted group reduce with type information type hint
case 7: {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Integer> resultDs = ds
.groupBy(0)
.sortGroup(0, Order.ASCENDING)
.reduceGroup(new GroupReducer<Tuple3<Integer, Long, String>, Integer>())
.returns(BasicTypeInfo.INT_TYPE_INFO);
List<Integer> result = resultDs.collect();
String expectedResult = "2\n" +
"3\n" +
"1\n";
compareResultAsText(result, expectedResult);
break;
}
// Test combine group with type information type hint
case 8: {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Integer> resultDs = ds
.groupBy(0)
.combineGroup(new GroupCombiner<Tuple3<Integer, Long, String>, Integer>())
.returns(BasicTypeInfo.INT_TYPE_INFO);
List<Integer> result = resultDs.collect();
String expectedResult = "2\n" +
"3\n" +
"1\n";
compareResultAsText(result, expectedResult);
break;
}
// Test cogroup with type information type hint
case 9: {
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Tuple3<Integer, Long, String>> ds2 = CollectionDataSets.getSmall3TupleDataSet(env);
DataSet<Integer> resultDs = ds1
.coGroup(ds2)
.where(0)
.equalTo(0)
.with(new CoGrouper<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>, Integer>())
.returns(BasicTypeInfo.INT_TYPE_INFO);
List<Integer> result = resultDs.collect();
String expectedResult = "2\n" +
"3\n" +
"1\n";
compareResultAsText(result, expectedResult);
break;
}
default:
throw new IllegalArgumentException("Invalid program id");
}
......@@ -154,4 +278,49 @@ public class TypeHintITCase extends JavaProgramTestBase {
}
}
/**
 * Generic test {@link JoinFunction} that returns field {@code f0} of the first input.
 *
 * <p>The first input is cast to {@code Tuple3} and the result is cast to {@code OUT}
 * without any runtime check, so callers must instantiate it with a {@code Tuple3}
 * first input whose {@code f0} is assignable to {@code OUT}. The type hint tests in
 * this class pair it with an explicit {@code returns(...)} call.
 *
 * @param <IN1> type of the first join input (expected to be a {@code Tuple3})
 * @param <IN2> type of the second join input (ignored)
 * @param <OUT> type of the emitted value
 */
public static class Joiner<IN1, IN2, OUT> implements JoinFunction<IN1, IN2, OUT> {
    private static final long serialVersionUID = 1L;

    /**
     * Returns {@code f0} of {@code first}; {@code second} is not read.
     *
     * @throws ClassCastException if {@code first} is not a {@code Tuple3}
     */
    @Override
    @SuppressWarnings("unchecked") // f0 is cast to OUT; safe only because the tests pick matching type arguments
    public OUT join(IN1 first, IN2 second) throws Exception {
        // Wildcard cast instead of the raw type avoids a rawtypes warning.
        final Tuple3<?, ?, ?> tuple = (Tuple3<?, ?, ?>) first;
        return (OUT) tuple.f0;
    }
}
/**
 * Generic test {@link FlatJoinFunction} that emits field {@code f0} of the first input.
 *
 * <p>The first input is cast to {@code Tuple3} and the emitted value is cast to
 * {@code OUT} without any runtime check, so callers must instantiate it with a
 * {@code Tuple3} first input whose {@code f0} is assignable to {@code OUT}. The type
 * hint tests in this class pair it with an explicit {@code returns(...)} call.
 *
 * @param <IN1> type of the first join input (expected to be a {@code Tuple3})
 * @param <IN2> type of the second join input (ignored)
 * @param <OUT> type of the emitted value
 */
public static class FlatJoiner<IN1, IN2, OUT> implements FlatJoinFunction<IN1, IN2, OUT> {
    private static final long serialVersionUID = 1L;

    /**
     * Collects exactly one record per joined pair: {@code f0} of {@code first}.
     *
     * @throws ClassCastException if {@code first} is not a {@code Tuple3}
     */
    @Override
    @SuppressWarnings("unchecked") // f0 is cast to OUT; safe only because the tests pick matching type arguments
    public void join(IN1 first, IN2 second, Collector<OUT> out) throws Exception {
        // Wildcard cast instead of the raw type avoids a rawtypes warning.
        final Tuple3<?, ?, ?> tuple = (Tuple3<?, ?, ?>) first;
        out.collect((OUT) tuple.f0);
    }
}
/**
 * Generic test {@link GroupReduceFunction} that emits field {@code f0} of the first
 * element of each group.
 *
 * <p>The element is cast to {@code Tuple3} and the emitted value is cast to
 * {@code OUT} without any runtime check, so callers must instantiate it with a
 * {@code Tuple3} input whose {@code f0} is assignable to {@code OUT}. The type hint
 * tests in this class pair it with an explicit {@code returns(...)} call.
 *
 * @param <IN> type of the grouped elements (expected to be a {@code Tuple3})
 * @param <OUT> type of the emitted value
 */
public static class GroupReducer<IN, OUT> implements GroupReduceFunction<IN, OUT> {
    private static final long serialVersionUID = 1L;

    /**
     * Collects one record per group: {@code f0} of the first element; the remaining
     * elements are not consumed.
     *
     * @throws ClassCastException if the first element is not a {@code Tuple3}
     */
    @Override
    @SuppressWarnings("unchecked") // f0 is cast to OUT; safe only because the tests pick matching type arguments
    public void reduce(Iterable<IN> values, Collector<OUT> out) throws Exception {
        // Iterator.next() throws NoSuchElementException on an empty group; the
        // original code relied on groups being non-empty and that is kept as is.
        final Tuple3<?, ?, ?> head = (Tuple3<?, ?, ?>) values.iterator().next();
        out.collect((OUT) head.f0);
    }
}
/**
 * Generic test {@link GroupCombineFunction} that emits field {@code f0} of the first
 * element handed to each {@code combine} call.
 *
 * <p>The element is cast to {@code Tuple3} and the emitted value is cast to
 * {@code OUT} without any runtime check, so callers must instantiate it with a
 * {@code Tuple3} input whose {@code f0} is assignable to {@code OUT}. The type hint
 * tests in this class pair it with an explicit {@code returns(...)} call.
 *
 * @param <IN> type of the grouped elements (expected to be a {@code Tuple3})
 * @param <OUT> type of the emitted value
 */
public static class GroupCombiner<IN, OUT> implements GroupCombineFunction<IN, OUT> {
    private static final long serialVersionUID = 1L;

    /**
     * Collects one record per invocation: {@code f0} of the first element; the
     * remaining elements are not consumed.
     *
     * @throws ClassCastException if the first element is not a {@code Tuple3}
     */
    @Override
    @SuppressWarnings("unchecked") // f0 is cast to OUT; safe only because the tests pick matching type arguments
    public void combine(Iterable<IN> values, Collector<OUT> out) throws Exception {
        // Iterator.next() throws NoSuchElementException on an empty iterable; the
        // original code relied on a non-empty input and that is kept as is.
        final Tuple3<?, ?, ?> head = (Tuple3<?, ?, ?>) values.iterator().next();
        out.collect((OUT) head.f0);
    }
}
/**
 * Generic test {@link CoGroupFunction} that emits field {@code f0} of the first
 * element of the first input group.
 *
 * <p>The element is cast to {@code Tuple3} and the emitted value is cast to
 * {@code OUT} without any runtime check, so callers must instantiate it with a
 * {@code Tuple3} first input whose {@code f0} is assignable to {@code OUT}. The type
 * hint tests in this class pair it with an explicit {@code returns(...)} call.
 *
 * @param <IN1> element type of the first input (expected to be a {@code Tuple3})
 * @param <IN2> element type of the second input (ignored)
 * @param <OUT> type of the emitted value
 */
public static class CoGrouper<IN1, IN2, OUT> implements CoGroupFunction<IN1, IN2, OUT> {
    private static final long serialVersionUID = 1L;

    /**
     * Collects one record per key: {@code f0} of the first element of {@code first};
     * {@code second} is not read.
     *
     * @throws ClassCastException if the first element is not a {@code Tuple3}
     */
    @Override
    @SuppressWarnings("unchecked") // f0 is cast to OUT; safe only because the tests pick matching type arguments
    public void coGroup(Iterable<IN1> first, Iterable<IN2> second, Collector<OUT> out) throws Exception {
        // Iterator.next() throws NoSuchElementException if the first group is empty.
        // NOTE(review): presumably fine here because the test co-groups a data set
        // with an identical copy of itself -- confirm before reusing elsewhere.
        final Tuple3<?, ?, ?> head = (Tuple3<?, ?, ?>) first.iterator().next();
        out.collect((OUT) head.f0);
    }
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册