diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/CoGroupOperatorBase.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/CoGroupOperatorBase.java index 476a6577ad1718847743ffaa60ce191158e70f37..87bcc2d59d0a30b591fe08dfca7b9fe1a31dd0be 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/CoGroupOperatorBase.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/CoGroupOperatorBase.java @@ -19,7 +19,10 @@ package org.apache.flink.api.common.operators.base; +import java.util.List; + import org.apache.flink.api.common.functions.CoGroupFunction; +import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.operators.BinaryOperatorInformation; import org.apache.flink.api.common.operators.DualInputOperator; import org.apache.flink.api.common.operators.Ordering; @@ -152,4 +155,10 @@ public class CoGroupOperatorBase executeOnCollections(List inputData1, List inputData2, RuntimeContext runtimeContext) throws Exception { + // TODO Auto-generated method stub + return null; + } } diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/DeltaIterationBase.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/DeltaIterationBase.java index 7ac17eb8b598f2e7b7f6f4cb763e6b00d6f2bc58..9a834f93ffd84f18cb37c1b328450afb9631ca4e 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/DeltaIterationBase.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/DeltaIterationBase.java @@ -20,10 +20,12 @@ package org.apache.flink.api.common.operators.base; import java.util.Collections; +import java.util.List; import java.util.Map; import org.apache.flink.api.common.aggregators.AggregatorRegistry; import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.operators.BinaryOperatorInformation; import org.apache.flink.api.common.operators.DualInputOperator; import org.apache.flink.api.common.operators.IterationOperator; @@ -329,4 +331,10 @@ public class DeltaIterationBase extends DualInputOperator executeOnCollections(List inputData1, List inputData2, RuntimeContext runtimeContext) throws Exception { + // TODO Auto-generated method stub + return null; + } } diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/JoinOperatorBase.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/JoinOperatorBase.java index 8468ed43157f56d74d6de7b6e0d5a12471334380..34ede6529487ec73835136191be85415cca2c682 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/JoinOperatorBase.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/JoinOperatorBase.java @@ -20,11 +20,26 @@ package org.apache.flink.api.common.operators.base; import org.apache.flink.api.common.functions.FlatJoinFunction; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.api.common.functions.util.FunctionUtils; +import org.apache.flink.api.common.functions.util.ListCollector; import org.apache.flink.api.common.operators.BinaryOperatorInformation; import org.apache.flink.api.common.operators.DualInputOperator; import org.apache.flink.api.common.operators.util.UserCodeClassWrapper; import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper; import org.apache.flink.api.common.operators.util.UserCodeWrapper; +import org.apache.flink.api.common.typeinfo.AtomicType; +import org.apache.flink.api.common.typeinfo.CompositeType; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.GenericPairComparator; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.common.typeutils.TypePairComparator; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; /** * @see org.apache.flink.api.common.functions.FlatJoinFunction @@ -34,7 +49,7 @@ public class JoinOperatorBase udf, BinaryOperatorInformation operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) { super(udf, operatorInfo, keyPositions1, keyPositions2, name); } - + public JoinOperatorBase(FT udf, BinaryOperatorInformation operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) { super(new UserCodeObjectWrapper(udf), operatorInfo, keyPositions1, keyPositions2, name); } @@ -42,4 +57,83 @@ public class JoinOperatorBase udf, BinaryOperatorInformation operatorInfo, int[] keyPositions1, int[] keyPositions2, String name) { super(new UserCodeClassWrapper(udf), operatorInfo, keyPositions1, keyPositions2, name); } + + @SuppressWarnings("unchecked") + @Override + protected List executeOnCollections(List inputData1, List inputData2, RuntimeContext runtimeContext) throws Exception { + FlatJoinFunction function = userFunction.getUserCodeObject(); + + FunctionUtils.setFunctionRuntimeContext(function, runtimeContext); + FunctionUtils.openFunction(function, this.parameters); + + TypeInformation leftInformation = getOperatorInfo().getFirstInputType(); + TypeInformation rightInformation = getOperatorInfo().getSecondInputType(); + + TypeComparator leftComparator; + TypeComparator rightComparator; + + if(leftInformation instanceof AtomicType){ + leftComparator = ((AtomicType) leftInformation).createComparator(true); + }else if(leftInformation instanceof CompositeType){ + int[] keyPositions = getKeyColumns(0); + boolean[] orders = new boolean[keyPositions.length]; + Arrays.fill(orders, true); + + leftComparator = ((CompositeType) leftInformation).createComparator(keyPositions, orders); + }else{ + throw new RuntimeException("Type information for left input of type " + leftInformation.getClass() + .getCanonicalName() + " is not supported. Could not generate a comparator."); + } + + if(rightInformation instanceof AtomicType){ + rightComparator = ((AtomicType) rightInformation).createComparator(true); + }else if(rightInformation instanceof CompositeType){ + int[] keyPositions = getKeyColumns(1); + boolean[] orders = new boolean[keyPositions.length]; + Arrays.fill(orders, true); + + rightComparator = ((CompositeType) rightInformation).createComparator(keyPositions, orders); + }else{ + throw new RuntimeException("Type information for right input of type " + rightInformation.getClass() + .getCanonicalName() + " is not supported. Could not generate a comparator."); + } + + TypePairComparator pairComparator = new GenericPairComparator(leftComparator, + rightComparator); + + List result = new ArrayList(); + ListCollector collector = new ListCollector(result); + + Map> probeTable = new HashMap>(); + + //Build probe table + for(IN2 element: inputData2){ + List list = probeTable.get(rightComparator.hash(element)); + if(list == null){ + list = new ArrayList(); + probeTable.put(rightComparator.hash(element), list); + } + + list.add(element); + } + + //Probing + for(IN1 left: inputData1){ + List matchingHashes = probeTable.get(leftComparator.hash(left)); + + pairComparator.setReference(left); + + if(matchingHashes != null){ + for(IN2 right: matchingHashes){ + if(pairComparator.equalToReference(right)){ + function.join(left, right, collector); + } + } + } + } + + FunctionUtils.closeFunction(function); + + return result; + } } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/GenericPairComparator.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/GenericPairComparator.java similarity index 98% rename from flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/GenericPairComparator.java rename to flink-core/src/main/java/org/apache/flink/api/common/typeutils/GenericPairComparator.java index 562c7061f51f9f1c50a740a7415c5ab27d21ca05..b64a3a3bc4590e655d732cb185236ad8bb049167 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/GenericPairComparator.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/GenericPairComparator.java @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.api.java.typeutils.runtime; +package org.apache.flink.api.common.typeutils; import java.io.Serializable; diff --git a/flink-core/src/test/java/org/apache/flink/api/common/operators/base/JoinOperatorBaseTest.java b/flink-core/src/test/java/org/apache/flink/api/common/operators/base/JoinOperatorBaseTest.java new file mode 100644 index 0000000000000000000000000000000000000000..b048cc5afdd7505a82b730dea108089e09ade44a --- /dev/null +++ b/flink-core/src/test/java/org/apache/flink/api/common/operators/base/JoinOperatorBaseTest.java @@ -0,0 +1,122 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.common.operators.base; + +import static org.junit.Assert.*; + +import org.apache.flink.api.common.functions.FlatJoinFunction; +import org.apache.flink.api.common.functions.RichFlatJoinFunction; +import org.apache.flink.api.common.functions.util.RuntimeUDFContext; +import org.apache.flink.api.common.operators.BinaryOperatorInformation; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.util.Collector; +import org.junit.Test; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +@SuppressWarnings("serial") +public class JoinOperatorBaseTest implements Serializable { + + @Test + public void testJoinPlain(){ + final FlatJoinFunction joiner = new FlatJoinFunction() { + + @Override + public void join(String first, String second, Collector out) throws Exception { + out.collect(first.length()); + out.collect(second.length()); + } + }; + + @SuppressWarnings({ "rawtypes", "unchecked" }) + JoinOperatorBase > base = new JoinOperatorBase(joiner, + new BinaryOperatorInformation(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO), new int[0], new int[0], "TestJoiner"); + + List inputData1 = new ArrayList(Arrays.asList("foo", "bar", "foobar")); + List inputData2 = new ArrayList(Arrays.asList("foobar", "foo")); + List expected = new ArrayList(Arrays.asList(3, 3, 6 ,6)); + + try { + List result = base.executeOnCollections(inputData1, inputData2, null); + + assertEquals(expected, result); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testJoinRich(){ + final AtomicBoolean opened = new AtomicBoolean(false); + final AtomicBoolean closed = new AtomicBoolean(false); + final String taskName = "Test rich join function"; + + final RichFlatJoinFunction joiner = new RichFlatJoinFunction() { + @Override + public void open(Configuration parameters) throws Exception { + opened.compareAndSet(false, true); + assertEquals(0, getRuntimeContext().getIndexOfThisSubtask()); + assertEquals(1, getRuntimeContext().getNumberOfParallelSubtasks()); + } + + @Override + public void close() throws Exception{ + closed.compareAndSet(false, true); + } + + @Override + public void join(String first, String second, Collector out) throws Exception { + out.collect(first.length()); + out.collect(second.length()); + } + }; + + JoinOperatorBase> base = new JoinOperatorBase>(joiner, new BinaryOperatorInformation(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.INT_TYPE_INFO), new int[0], new int[0], taskName); + + final List inputData1 = new ArrayList(Arrays.asList("foo", "bar", "foobar")); + final List inputData2 = new ArrayList(Arrays.asList("foobar", "foo")); + final List expected = new ArrayList(Arrays.asList(3, 3, 6, 6)); + + + try { + List result = base.executeOnCollections(inputData1, inputData2, new RuntimeUDFContext(taskName, + 1, 0)); + + assertEquals(expected, result); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + + assertTrue(opened.get()); + assertTrue(closed.get()); + } +} diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/RuntimePairComparatorFactory.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/RuntimePairComparatorFactory.java index e81be2f9172be63524b2727c6984ed2043034367..eee6643ab5ea7f2347d6affabe0641f6a664aff1 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/RuntimePairComparatorFactory.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/RuntimePairComparatorFactory.java @@ -18,6 +18,7 @@ package org.apache.flink.api.java.typeutils.runtime; +import org.apache.flink.api.common.typeutils.GenericPairComparator; import org.apache.flink.api.common.typeutils.TypeComparator; import org.apache.flink.api.common.typeutils.TypePairComparator; import org.apache.flink.api.common.typeutils.TypePairComparatorFactory; diff --git a/flink-java/src/test/java/org/apache/flink/api/java/operators/base/JoinOperatorBaseTest.java b/flink-java/src/test/java/org/apache/flink/api/java/operators/base/JoinOperatorBaseTest.java new file mode 100644 index 0000000000000000000000000000000000000000..f332832ec73eab8a237e067a0d0100be0ccf3567 --- /dev/null +++ b/flink-java/src/test/java/org/apache/flink/api/java/operators/base/JoinOperatorBaseTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.java.operators.base; + +import static org.junit.Assert.*; + +import org.apache.flink.api.common.functions.FlatJoinFunction; +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.api.common.operators.BinaryOperatorInformation; +import org.apache.flink.api.common.operators.base.JoinOperatorBase; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.util.Collector; +import org.junit.Test; + +import java.io.Serializable; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +public class JoinOperatorBaseTest implements Serializable { + + @Test + public void testTupleBaseJoiner(){ + final FlatJoinFunction, Tuple2, Tuple2> joiner = new FlatJoinFunction() { + @Override + public void join(Object first, Object second, Collector out) throws Exception { + Tuple3 fst = (Tuple3)first; + Tuple2 snd = (Tuple2)second; + + assertEquals(fst.f0, snd.f1); + assertEquals(fst.f2, snd.f0); + + out.collect(new Tuple2(fst.f1, snd.f0.toString())); + } + }; + + final TupleTypeInfo> leftTypeInfo = TupleTypeInfo.getBasicTupleTypeInfo + (String.class, Double.class, Integer.class); + final TupleTypeInfo> rightTypeInfo = TupleTypeInfo.getBasicTupleTypeInfo(Integer.class, + String.class); + final TupleTypeInfo> outTypeInfo = TupleTypeInfo.getBasicTupleTypeInfo(Double.class, + String.class); + + final int[] leftKeys = new int[]{0,2}; + final int[] rightKeys = new int[]{1,0}; + + final String taskName = "Collection based tuple joiner"; + + final BinaryOperatorInformation, Tuple2, Tuple2> binaryOpInfo = new BinaryOperatorInformation, Tuple2, Tuple2>(leftTypeInfo, rightTypeInfo, outTypeInfo); + + final JoinOperatorBase, Tuple2, Tuple2, FlatJoinFunction, Tuple2, Tuple2>> base = new JoinOperatorBase, + Tuple2, Tuple2, FlatJoinFunction, + Tuple2, Tuple2>>(joiner, binaryOpInfo, leftKeys, rightKeys, taskName); + + final List > inputData1 = new ArrayList>(Arrays.asList( + new Tuple3("foo", 42.0, 1), + new Tuple3("bar", 1.0, 2), + new Tuple3("bar", 2.0, 3), + new Tuple3("foobar", 3.0, 4), + new Tuple3("bar", 3.0, 3) + )); + + final List> inputData2 = new ArrayList>(Arrays.asList( + new Tuple2(3, "bar"), + new Tuple2(4, "foobar"), + new Tuple2(2, "foo") + )); + final Set> expected = new HashSet>(Arrays.asList( + new Tuple2(2.0, "3"), + new Tuple2(3.0, "3"), + new Tuple2(3.0, "4") + )); + + try { + Method executeOnCollections = base.getClass().getDeclaredMethod("executeOnCollections", List.class, + List.class, RuntimeContext.class); + executeOnCollections.setAccessible(true); + + Object result = executeOnCollections.invoke(base, inputData1, inputData2, null); + + assertEquals(expected, new HashSet>((List>)result)); + + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + + } +} diff --git a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/GenericPairComparatorTest.java b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/GenericPairComparatorTest.java index b8734880a23b83b73257b51e8eeb894ef04ad2fe..382b00911a657b8810537ed9b6514149fb63cc03 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/GenericPairComparatorTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/typeutils/runtime/GenericPairComparatorTest.java @@ -18,6 +18,7 @@ package org.apache.flink.api.java.typeutils.runtime; +import org.apache.flink.api.common.typeutils.GenericPairComparator; import org.apache.flink.api.common.typeutils.TypeComparator; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.base.DoubleComparator; diff --git a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/GenericPairComparatorTest.scala b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/GenericPairComparatorTest.scala index d0be032a53484ccbc7568996480567c9c9b1e1a1..9fb098006edd9625d6845a1fcd99c012db52a9e1 100644 --- a/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/GenericPairComparatorTest.scala +++ b/flink-scala/src/test/scala/org/apache/flink/api/scala/runtime/GenericPairComparatorTest.scala @@ -17,10 +17,10 @@ */ package org.apache.flink.api.scala.runtime -import org.apache.flink.api.common.typeutils.{TypeComparator, TypeSerializer} +import org.apache.flink.api.common.typeutils.{GenericPairComparator, TypeComparator, TypeSerializer} import org.apache.flink.api.common.typeutils.base.{DoubleComparator, DoubleSerializer, IntComparator, IntSerializer} -import org.apache.flink.api.java.typeutils.runtime.{GenericPairComparator, TupleComparator} +import org.apache.flink.api.java.typeutils.runtime.TupleComparator import org.apache.flink.api.scala.runtime.tuple.base.PairComparatorTestBase import org.apache.flink.api.scala.typeutils.CaseClassComparator