Commit 471f340d authored by Till Rohrmann, committed by Stephan Ewen

[FLINK-1110] Started implementing the JoinOperatorBase.

Implemented JoinOperatorBase and test cases.
Parent 77ac6c00
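
The core of the change is JoinOperatorBase#executeOnCollections, which runs the join as an in-memory hash join: the second input is hashed by its join key, the first input probes that table, and candidate pairs are confirmed with a TypePairComparator before the user's FlatJoinFunction is called. A minimal standalone sketch of the same strategy in plain Java (hypothetical names, java.util.function key extractors instead of Flink's comparators; not the actual Flink code) follows:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;

public class HashJoinSketch {

    // Joins two lists on keys extracted by the given functions; mirrors the build/probe
    // structure of executeOnCollections, but keys the table by the key itself rather than
    // by comparator hash, so no extra equality check is needed after the lookup.
    static <L, R, K, OUT> List<OUT> join(List<L> left, List<R> right,
                                         Function<L, K> leftKey, Function<R, K> rightKey,
                                         BiFunction<L, R, OUT> joiner) {
        // Build phase: group the second input by join key.
        Map<K, List<R>> table = new HashMap<>();
        for (R r : right) {
            table.computeIfAbsent(rightKey.apply(r), k -> new ArrayList<>()).add(r);
        }
        // Probe phase: look up each element of the first input and emit one result per match.
        List<OUT> result = new ArrayList<>();
        for (L l : left) {
            List<R> candidates = table.get(leftKey.apply(l));
            if (candidates != null) {
                for (R r : candidates) {
                    result.add(joiner.apply(l, r));
                }
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<String> first = Arrays.asList("foo", "bar", "foobar");
        List<String> second = Arrays.asList("foobar", "foo");
        // Join on the string itself and emit the combined length of each matching pair.
        List<Integer> joined = join(first, second, s -> s, s -> s, (l, r) -> l.length() + r.length());
        System.out.println(joined); // prints [6, 12]
    }
}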
@@ -19,7 +19,10 @@
package org.apache.flink.api.common.operators.base;
import java.util.List;
import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.operators.BinaryOperatorInformation;
import org.apache.flink.api.common.operators.DualInputOperator;
import org.apache.flink.api.common.operators.Ordering;
@@ -152,4 +155,10 @@ public class CoGroupOperatorBase<IN1, IN2, OUT, FT extends CoGroupFunction<IN1,
public void setCombinableSecond(boolean combinableSecond) {
this.combinableSecond = combinableSecond;
}
@Override
protected List<OUT> executeOnCollections(List<IN1> inputData1, List<IN2> inputData2, RuntimeContext runtimeContext) throws Exception {
// TODO Auto-generated method stub
return null;
}
}
@@ -20,10 +20,12 @@
package org.apache.flink.api.common.operators.base;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.aggregators.AggregatorRegistry;
import org.apache.flink.api.common.functions.AbstractRichFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.operators.BinaryOperatorInformation;
import org.apache.flink.api.common.operators.DualInputOperator;
import org.apache.flink.api.common.operators.IterationOperator;
@@ -329,4 +331,10 @@ public class DeltaIterationBase<ST, WT> extends DualInputOperator<ST, WT, ST, Ab
return null;
}
}
@Override
protected List<ST> executeOnCollections(List<ST> inputData1, List<WT> inputData2, RuntimeContext runtimeContext) throws Exception {
// TODO Auto-generated method stub
return null;
}
}
@@ -20,11 +20,26 @@
package org.apache.flink.api.common.operators.base;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.functions.util.ListCollector;
import org.apache.flink.api.common.operators.BinaryOperatorInformation;
import org.apache.flink.api.common.operators.DualInputOperator;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeinfo.AtomicType;
import org.apache.flink.api.common.typeinfo.CompositeType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.GenericPairComparator;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypePairComparator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* @see org.apache.flink.api.common.functions.FlatJoinFunction
@@ -34,7 +49,7 @@ public class JoinOperatorBase<IN1, IN2, OUT, FT extends FlatJoinFunction<IN1, IN
public JoinOperatorBase(UserCodeWrapper<FT> udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) {
super(udf, operatorInfo, keyPositions1, keyPositions2, name);
}
public JoinOperatorBase(FT udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) {
super(new UserCodeObjectWrapper<FT>(udf), operatorInfo, keyPositions1, keyPositions2, name);
}
@@ -42,4 +57,83 @@ public class JoinOperatorBase<IN1, IN2, OUT, FT extends FlatJoinFunction<IN1, IN
public JoinOperatorBase(Class<? extends FT> udf, BinaryOperatorInformation<IN1, IN2, OUT> operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) {
super(new UserCodeClassWrapper<FT>(udf), operatorInfo, keyPositions1, keyPositions2, name);
}
@SuppressWarnings("unchecked")
@Override
protected List<OUT> executeOnCollections(List<IN1> inputData1, List<IN2> inputData2, RuntimeContext runtimeContext) throws Exception {
FlatJoinFunction<IN1, IN2, OUT> function = userFunction.getUserCodeObject();
FunctionUtils.setFunctionRuntimeContext(function, runtimeContext);
FunctionUtils.openFunction(function, this.parameters);
TypeInformation<IN1> leftInformation = getOperatorInfo().getFirstInputType();
TypeInformation<IN2> rightInformation = getOperatorInfo().getSecondInputType();
TypeComparator<IN1> leftComparator;
TypeComparator<IN2> rightComparator;
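// Build a per-input comparator: atomic types compare the full value, while composite types
// (tuples/POJOs) compare the configured key columns in ascending order.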
if(leftInformation instanceof AtomicType){
leftComparator = ((AtomicType<IN1>) leftInformation).createComparator(true);
}else if(leftInformation instanceof CompositeType){
int[] keyPositions = getKeyColumns(0);
boolean[] orders = new boolean[keyPositions.length];
Arrays.fill(orders, true);
leftComparator = ((CompositeType<IN1>) leftInformation).createComparator(keyPositions, orders);
}else{
throw new RuntimeException("Type information for left input of type " + leftInformation.getClass()
.getCanonicalName() + " is not supported. Could not generate a comparator.");
}
if(rightInformation instanceof AtomicType){
rightComparator = ((AtomicType<IN2>) rightInformation).createComparator(true);
}else if(rightInformation instanceof CompositeType){
int[] keyPositions = getKeyColumns(1);
boolean[] orders = new boolean[keyPositions.length];
Arrays.fill(orders, true);
rightComparator = ((CompositeType<IN2>) rightInformation).createComparator(keyPositions, orders);
}else{
throw new RuntimeException("Type information for right input of type " + rightInformation.getClass()
.getCanonicalName() + " is not supported. Could not generate a comparator.");
}
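// The pair comparator takes a left-side element as reference and reports whether a right-side
// element has an equal join key, using the two comparators built above.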
TypePairComparator<IN1, IN2> pairComparator = new GenericPairComparator<IN1, IN2>(leftComparator,
rightComparator);
List<OUT> result = new ArrayList<OUT>();
ListCollector<OUT> collector = new ListCollector<OUT>(result);
Map<Integer, List<IN2>> probeTable = new HashMap<Integer, List<IN2>>();
//Build probe table
for(IN2 element: inputData2){
List<IN2> list = probeTable.get(rightComparator.hash(element));
if(list == null){
list = new ArrayList<IN2>();
probeTable.put(rightComparator.hash(element), list);
}
list.add(element);
}
//Probing
for(IN1 left: inputData1){
List<IN2> matchingHashes = probeTable.get(leftComparator.hash(left));
pairComparator.setReference(left);
if(matchingHashes != null){
for(IN2 right: matchingHashes){
if(pairComparator.equalToReference(right)){
function.join(left, right, collector);
}
}
}
}
FunctionUtils.closeFunction(function);
return result;
}
}
@@ -16,7 +16,7 @@
* limitations under the License.
*/
package org.apache.flink.api.java.typeutils.runtime;
package org.apache.flink.api.common.typeutils;
import java.io.Serializable;
......
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.common.operators.base;
import static org.junit.Assert.*;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.RichFlatJoinFunction;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.api.common.operators.BinaryOperatorInformation;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;
import org.junit.Test;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
@SuppressWarnings("serial")
public class JoinOperatorBaseTest implements Serializable {
@Test
public void testJoinPlain(){
final FlatJoinFunction<String, String, Integer> joiner = new FlatJoinFunction<String, String, Integer>() {
@Override
public void join(String first, String second, Collector<Integer> out) throws Exception {
out.collect(first.length());
out.collect(second.length());
}
};
@SuppressWarnings({ "rawtypes", "unchecked" })
JoinOperatorBase<String, String, Integer,
FlatJoinFunction<String, String,Integer> > base = new JoinOperatorBase(joiner,
new BinaryOperatorInformation(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO,
BasicTypeInfo.INT_TYPE_INFO), new int[0], new int[0], "TestJoiner");
List<String> inputData1 = new ArrayList<String>(Arrays.asList("foo", "bar", "foobar"));
List<String> inputData2 = new ArrayList<String>(Arrays.asList("foobar", "foo"));
List<Integer> expected = new ArrayList<Integer>(Arrays.asList(3, 3, 6, 6));
try {
List<Integer> result = base.executeOnCollections(inputData1, inputData2, null);
assertEquals(expected, result);
} catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
@Test
public void testJoinRich(){
final AtomicBoolean opened = new AtomicBoolean(false);
final AtomicBoolean closed = new AtomicBoolean(false);
final String taskName = "Test rich join function";
final RichFlatJoinFunction<String, String, Integer> joiner = new RichFlatJoinFunction<String, String, Integer>() {
@Override
public void open(Configuration parameters) throws Exception {
opened.compareAndSet(false, true);
assertEquals(0, getRuntimeContext().getIndexOfThisSubtask());
assertEquals(1, getRuntimeContext().getNumberOfParallelSubtasks());
}
@Override
public void close() throws Exception{
closed.compareAndSet(false, true);
}
@Override
public void join(String first, String second, Collector<Integer> out) throws Exception {
out.collect(first.length());
out.collect(second.length());
}
};
JoinOperatorBase<String, String, Integer,
RichFlatJoinFunction<String, String, Integer>> base = new JoinOperatorBase<String, String, Integer,
RichFlatJoinFunction<String, String, Integer>>(joiner, new BinaryOperatorInformation<String, String,
Integer>(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO,
BasicTypeInfo.INT_TYPE_INFO), new int[0], new int[0], taskName);
final List<String> inputData1 = new ArrayList<String>(Arrays.asList("foo", "bar", "foobar"));
final List<String> inputData2 = new ArrayList<String>(Arrays.asList("foobar", "foo"));
final List<Integer> expected = new ArrayList<Integer>(Arrays.asList(3, 3, 6, 6));
try {
List<Integer> result = base.executeOnCollections(inputData1, inputData2, new RuntimeUDFContext(taskName,
1, 0));
assertEquals(expected, result);
} catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
assertTrue(opened.get());
assertTrue(closed.get());
}
}
@@ -18,6 +18,7 @@
package org.apache.flink.api.java.typeutils.runtime;
import org.apache.flink.api.common.typeutils.GenericPairComparator;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypePairComparator;
import org.apache.flink.api.common.typeutils.TypePairComparatorFactory;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.operators.base;
import static org.junit.Assert.*;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.operators.BinaryOperatorInformation;
import org.apache.flink.api.common.operators.base.JoinOperatorBase;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.util.Collector;
import org.junit.Test;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
public class JoinOperatorBaseTest implements Serializable {
@Test
public void testTupleBaseJoiner(){
final FlatJoinFunction<Tuple3<String, Double, Integer>, Tuple2<Integer,
String>, Tuple2<Double, String>> joiner = new FlatJoinFunction() {
@Override
public void join(Object first, Object second, Collector out) throws Exception {
Tuple3<String, Double, Integer> fst = (Tuple3<String, Double, Integer>)first;
Tuple2<Integer, String> snd = (Tuple2<Integer, String>)second;
assertEquals(fst.f0, snd.f1);
assertEquals(fst.f2, snd.f0);
out.collect(new Tuple2<Double, String>(fst.f1, snd.f0.toString()));
}
};
final TupleTypeInfo<Tuple3<String, Double, Integer>> leftTypeInfo = TupleTypeInfo.getBasicTupleTypeInfo
(String.class, Double.class, Integer.class);
final TupleTypeInfo<Tuple2<Integer, String>> rightTypeInfo = TupleTypeInfo.getBasicTupleTypeInfo(Integer.class,
String.class);
final TupleTypeInfo<Tuple2<Double, String>> outTypeInfo = TupleTypeInfo.getBasicTupleTypeInfo(Double.class,
String.class);
final int[] leftKeys = new int[]{0,2};
final int[] rightKeys = new int[]{1,0};
final String taskName = "Collection based tuple joiner";
final BinaryOperatorInformation<Tuple3<String, Double, Integer>, Tuple2<Integer, String>, Tuple2<Double,
String>> binaryOpInfo = new BinaryOperatorInformation<Tuple3<String, Double, Integer>, Tuple2<Integer,
String>, Tuple2<Double, String>>(leftTypeInfo, rightTypeInfo, outTypeInfo);
final JoinOperatorBase<Tuple3<String, Double, Integer>, Tuple2<Integer,
String>, Tuple2<Double, String>, FlatJoinFunction<Tuple3<String, Double, Integer>, Tuple2<Integer,
String>, Tuple2<Double, String>>> base = new JoinOperatorBase<Tuple3<String, Double, Integer>,
Tuple2<Integer, String>, Tuple2<Double, String>, FlatJoinFunction<Tuple3<String, Double, Integer>,
Tuple2<Integer, String>, Tuple2<Double, String>>>(joiner, binaryOpInfo, leftKeys, rightKeys, taskName);
final List<Tuple3<String, Double, Integer> > inputData1 = new ArrayList<Tuple3<String, Double,
Integer>>(Arrays.asList(
new Tuple3<String, Double, Integer>("foo", 42.0, 1),
new Tuple3<String,Double, Integer>("bar", 1.0, 2),
new Tuple3<String, Double, Integer>("bar", 2.0, 3),
new Tuple3<String, Double, Integer>("foobar", 3.0, 4),
new Tuple3<String, Double, Integer>("bar", 3.0, 3)
));
final List<Tuple2<Integer, String>> inputData2 = new ArrayList<Tuple2<Integer, String>>(Arrays.asList(
new Tuple2<Integer, String>(3, "bar"),
new Tuple2<Integer, String>(4, "foobar"),
new Tuple2<Integer, String>(2, "foo")
));
final Set<Tuple2<Double, String>> expected = new HashSet<Tuple2<Double, String>>(Arrays.asList(
new Tuple2<Double, String>(2.0, "3"),
new Tuple2<Double, String>(3.0, "3"),
new Tuple2<Double, String>(3.0, "4")
));
try {
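// executeOnCollections is protected and declared in a different package, so it is invoked via reflection here.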
Method executeOnCollections = base.getClass().getDeclaredMethod("executeOnCollections", List.class,
List.class, RuntimeContext.class);
executeOnCollections.setAccessible(true);
Object result = executeOnCollections.invoke(base, inputData1, inputData2, null);
assertEquals(expected, new HashSet<Tuple2<Double, String>>((List<Tuple2<Double, String>>)result));
} catch (Exception e) {
e.printStackTrace();
fail(e.getMessage());
}
}
}
@@ -18,6 +18,7 @@
package org.apache.flink.api.java.typeutils.runtime;
import org.apache.flink.api.common.typeutils.GenericPairComparator;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.DoubleComparator;
......
@@ -17,10 +17,10 @@
*/
package org.apache.flink.api.scala.runtime
import org.apache.flink.api.common.typeutils.{TypeComparator, TypeSerializer}
import org.apache.flink.api.common.typeutils.{GenericPairComparator, TypeComparator, TypeSerializer}
import org.apache.flink.api.common.typeutils.base.{DoubleComparator, DoubleSerializer, IntComparator, IntSerializer}
import org.apache.flink.api.java.typeutils.runtime.{GenericPairComparator, TupleComparator}
import org.apache.flink.api.java.typeutils.runtime.TupleComparator
import org.apache.flink.api.scala.runtime.tuple.base.PairComparatorTestBase
import org.apache.flink.api.scala.typeutils.CaseClassComparator
......