diff --git a/flink-core/src/main/java/org/apache/flink/api/common/accumulators/ListAccumulator.java b/flink-core/src/main/java/org/apache/flink/api/common/accumulators/ListAccumulator.java
new file mode 100644
index 0000000000000000000000000000000000000000..5a973ac3887a820253a6bb6f8f3386c9e2d519a5
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/api/common/accumulators/ListAccumulator.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.flink.api.common.accumulators;
+
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+import java.util.ArrayList;
+
+import org.apache.commons.lang3.SerializationUtils;
+
+/**
+ * This accumulator stores a collection of objects which are immediately serialized to cope with object reuse.
+ * When the objects are requested again, they are deserialized.
+ *
+ * @param <T> The type of the accumulated objects
+ */
+public class ListAccumulator<T> implements Accumulator<T, ArrayList<T>> {
+
+	private static final long serialVersionUID = 1L;
+
+	private ArrayList<byte[]> localValue = new ArrayList<byte[]>();
+
+	@Override
+	public void add(T value) {
+		byte[] byteArray = SerializationUtils.serialize((Serializable) value);
+		localValue.add(byteArray);
+	}
+
+	@Override
+	public ArrayList<T> getLocalValue() {
+		ArrayList<T> arrList = new ArrayList<T>();
+		for (byte[] byteArr : localValue) {
+			T item = SerializationUtils.deserialize(byteArr);
+			arrList.add(item);
+		}
+		return arrList;
+	}
+
+	@Override
+	public void resetLocal() {
+		localValue.clear();
+	}
+
+	@Override
+	public void merge(Accumulator<T, ArrayList<T>> other) {
+		localValue.addAll(((ListAccumulator<T>) other).localValue);
+	}
+
+	@Override
+	public Accumulator<T, ArrayList<T>> clone() {
+		ListAccumulator<T> newInstance = new ListAccumulator<T>();
+		for (byte[] item : localValue) {
+			newInstance.localValue.add(item.clone());
+		}
+		return newInstance;
+	}
+
+	@Override
+	public void write(ObjectOutputStream out) throws IOException {
+		int numItems = localValue.size();
+		out.writeInt(numItems);
+		for (byte[] item : localValue) {
+			out.writeInt(item.length);
+			out.write(item);
+		}
+	}
+
+	@Override
+	public void read(ObjectInputStream in) throws IOException {
+		int numItems = in.readInt();
+		for (int i = 0; i < numItems; i++) {
+			int len = in.readInt();
+			byte[] obj = new byte[len];
+			// read(byte[]) may return fewer bytes than requested;
+			// readFully blocks until the whole serialized object is read
+			in.readFully(obj);
+			localValue.add(obj);
+		}
+	}
+}
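
A minimal sketch of the eager-serialization behavior the class comment describes, outside of any Flink job (assuming flink-core and commons-lang3 on the classpath; the class name ListAccumulatorDemo is illustrative, not part of this patch). Because add() serializes immediately, later mutation of an added object cannot corrupt the stored value, which is exactly the object-reuse hazard this accumulator guards against:

    import java.util.ArrayList;
    import org.apache.flink.api.common.accumulators.ListAccumulator;

    public class ListAccumulatorDemo {
        public static void main(String[] args) {
            ListAccumulator<StringBuilder> acc = new ListAccumulator<StringBuilder>();

            StringBuilder reused = new StringBuilder("first");
            acc.add(reused);           // serialized right away
            reused.setLength(0);
            reused.append("second");   // mutate the same instance, as a reusing runtime would
            acc.add(reused);

            // getLocalValue() deserializes fresh copies of every stored element
            ArrayList<StringBuilder> values = acc.getLocalValue();
            System.out.println(values.get(0)); // prints "first": the mutation did not leak in
            System.out.println(values.get(1)); // prints "second"
        }
    }
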
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
index 327a15af065a36664b43768dacb50b3d401e8657..d7e6e942c818932b1f71f345b2ada13a5b29c5f8 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
@@ -18,8 +18,11 @@
 package org.apache.flink.api.java;
 
+import java.util.List;
+
 import org.apache.commons.lang3.Validate;
 import org.apache.flink.api.common.InvalidProgramException;
+import org.apache.flink.api.common.JobExecutionResult;
 import org.apache.flink.api.common.functions.FilterFunction;
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.api.common.functions.GroupReduceFunction;
@@ -36,11 +39,13 @@ import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint;
 import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint;
 import org.apache.flink.api.common.operators.base.PartitionOperatorBase.PartitionMethod;
 import org.apache.flink.api.java.aggregation.Aggregations;
+import org.apache.flink.api.java.functions.FirstReducer;
 import org.apache.flink.api.java.functions.FormattingMapper;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.functions.SelectByMaxFunction;
 import org.apache.flink.api.java.functions.SelectByMinFunction;
 import org.apache.flink.api.java.io.CsvOutputFormat;
+import org.apache.flink.api.java.io.DiscardingOutputFormat;
 import org.apache.flink.api.java.io.PrintingOutputFormat;
 import org.apache.flink.api.java.io.TextOutputFormat;
 import org.apache.flink.api.java.io.TextOutputFormat.TextFormatter;
@@ -54,7 +59,6 @@ import org.apache.flink.api.java.operators.DeltaIteration;
 import org.apache.flink.api.java.operators.DistinctOperator;
 import org.apache.flink.api.java.operators.FilterOperator;
 import org.apache.flink.api.java.operators.ProjectOperator;
-import org.apache.flink.api.java.functions.FirstReducer;
 import org.apache.flink.api.java.operators.FlatMapOperator;
 import org.apache.flink.api.java.operators.GroupReduceOperator;
 import org.apache.flink.api.java.operators.IterativeDataSet;
@@ -78,6 +82,7 @@ import org.apache.flink.api.java.typeutils.TupleTypeInfo;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.core.fs.FileSystem.WriteMode;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.util.AbstractID;
 
 import com.google.common.base.Preconditions;
 
@@ -371,7 +376,45 @@ public abstract class DataSet<T> {
 	public AggregateOperator<T> min(int field) {
 		return aggregate(Aggregations.MIN, field);
 	}
-
+
+	/**
+	 * Convenience method to get the count (number of elements) of a DataSet.
+	 *
+	 * @return A long integer that represents the number of elements in the set.
+	 *
+	 * @see org.apache.flink.api.java.Utils.CountHelper
+	 */
+	public long count() throws Exception {
+
+		final String id = new AbstractID().toString();
+
+		flatMap(new Utils.CountHelper<T>(id)).output(
+				new DiscardingOutputFormat<Long>());
+
+		JobExecutionResult res = getExecutionEnvironment().execute();
+		return res.<Long> getAccumulatorResult(id);
+	}
+
+	/**
+	 * Convenience method to get the elements of a DataSet as a List.
+	 * As DataSet can contain a lot of data, this method should be used with caution.
+	 *
+	 * @return A List containing the elements of the DataSet
+	 *
+	 * @see org.apache.flink.api.java.Utils.CollectHelper
+	 */
+	public List<T> collect() throws Exception {
+
+		final String id = new AbstractID().toString();
+
+		this.flatMap(new Utils.CollectHelper<T>(id)).output(
+				new DiscardingOutputFormat<T>());
+
+		JobExecutionResult res = this.getExecutionEnvironment().execute();
+
+		return (List<T>) res.getAccumulatorResult(id);
+	}
+
 	/**
 	 * Applies a Reduce transformation on a non-grouped {@link DataSet}.
 	 * The transformation consecutively calls a {@link org.apache.flink.api.common.functions.RichReduceFunction}
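
A hedged usage sketch of the two new convenience methods (program and data are illustrative, not part of this patch). Note that, as implemented above, each call submits its own job via getExecutionEnvironment().execute(), so invoking both count() and collect() executes the program twice:

    import java.util.List;
    import org.apache.flink.api.java.DataSet;
    import org.apache.flink.api.java.ExecutionEnvironment;

    public class CountCollectExample {
        public static void main(String[] args) throws Exception {
            ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
            DataSet<String> words = env.fromElements("to", "be", "or", "not", "to", "be");

            long n = words.count();             // first job execution
            List<String> all = words.collect(); // second job execution

            System.out.println(n + " words: " + all);
        }
    }
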
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/Utils.java b/flink-java/src/main/java/org/apache/flink/api/java/Utils.java
index 21df16895d9d91e24eac9eb7ecd8907a87304889..82cd0bc1503eee667ab51690dee15424ca180820 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/Utils.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/Utils.java
@@ -23,6 +23,11 @@ import org.apache.flink.api.common.typeutils.CompositeType;
 import org.apache.flink.api.java.typeutils.GenericTypeInfo;
 
 import java.util.List;
 
+import org.apache.flink.api.common.accumulators.ListAccumulator;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.util.Collector;
+
 public class Utils {
 
@@ -60,4 +65,51 @@ public class Utils {
 			}
 		}
 	}
+
+	public static class CountHelper<T> extends RichFlatMapFunction<T, Long> {
+
+		private static final long serialVersionUID = 1L;
+
+		private final String id;
+		private long counter;
+
+		public CountHelper(String id) {
+			this.id = id;
+			this.counter = 0L;
+		}
+
+		@Override
+		public void flatMap(T value, Collector<Long> out) throws Exception {
+			counter++;
+		}
+
+		@Override
+		public void close() throws Exception {
+			getRuntimeContext().getLongCounter(id).add(counter);
+		}
+	}
+
+	public static class CollectHelper<T> extends RichFlatMapFunction<T, T> {
+
+		private static final long serialVersionUID = 1L;
+
+		private final String id;
+		private final ListAccumulator<T> accumulator;
+
+		public CollectHelper(String id) {
+			this.id = id;
+			this.accumulator = new ListAccumulator<T>();
+		}
+
+		@Override
+		public void open(Configuration parameters) throws Exception {
+			getRuntimeContext().addAccumulator(id, accumulator);
+		}
+
+		@Override
+		public void flatMap(T value, Collector<T> out) throws Exception {
+			accumulator.add(value);
+		}
+	}
 }
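
Both helpers follow the same pattern: a RichFlatMapFunction that emits nothing, feeds an accumulator registered under a caller-supplied unique id, and lets the caller read the merged result back from the JobExecutionResult. As a sketch of how the pattern generalizes to other aggregates (SumHelper is hypothetical and not part of this patch):

    import org.apache.flink.api.common.functions.RichFlatMapFunction;
    import org.apache.flink.util.Collector;

    // Hypothetical helper in the style of CountHelper: sums Integer elements
    // locally and publishes the per-subtask total on close().
    public class SumHelper extends RichFlatMapFunction<Integer, Integer> {

        private static final long serialVersionUID = 1L;

        private final String id;
        private long sum;

        public SumHelper(String id) {
            this.id = id;
            this.sum = 0L;
        }

        @Override
        public void flatMap(Integer value, Collector<Integer> out) throws Exception {
            sum += value; // nothing is collected; the result travels via the accumulator
        }

        @Override
        public void close() throws Exception {
            // long counters from all parallel subtasks are merged by the runtime
            getRuntimeContext().getLongCounter(id).add(sum);
        }
    }
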
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
index 8a1dc41e0744358508672153f53471c1b9d045e6..b93b3f7e2e85393d6e722b7a97077316cd0b756c 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
@@ -17,6 +17,7 @@
  */
 package org.apache.flink.api.scala
 
+import java.lang
 import org.apache.commons.lang3.Validate
 import org.apache.flink.api.common.InvalidProgramException
 import org.apache.flink.api.common.aggregators.Aggregator
@@ -24,22 +25,24 @@ import org.apache.flink.api.common.functions._
 import org.apache.flink.api.common.io.{FileOutputFormat, OutputFormat}
 import org.apache.flink.api.common.operators.Order
 import org.apache.flink.api.common.operators.base.PartitionOperatorBase.PartitionMethod
+import org.apache.flink.api.java.Utils.CountHelper
 import org.apache.flink.api.java.aggregation.Aggregations
 import org.apache.flink.api.java.functions.{FirstReducer, KeySelector}
-import org.apache.flink.api.java.io.{PrintingOutputFormat, TextOutputFormat}
+import org.apache.flink.api.java.io.{DiscardingOutputFormat, PrintingOutputFormat, TextOutputFormat}
 import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint
 import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
 import org.apache.flink.api.java.operators.Keys.ExpressionKeys
 import org.apache.flink.api.java.operators._
-import org.apache.flink.api.java.{DataSet => JavaDataSet, SortPartitionOperator}
+import org.apache.flink.api.java.{DataSet => JavaDataSet, SortPartitionOperator, Utils}
 import org.apache.flink.api.scala.operators.{ScalaCsvOutputFormat, ScalaAggregateOperator}
 import org.apache.flink.configuration.Configuration
 import org.apache.flink.core.fs.{FileSystem, Path}
 import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.util.Collector
+import org.apache.flink.util.{AbstractID, Collector}
 
 import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
 
+
 /**
  * The DataSet, the basic abstraction of Flink. This represents a collection of elements of a
  * specific type `T`. The operations in this class can be used to create new DataSets and to combine
@@ -508,6 +511,37 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
     aggregate(Aggregations.MIN, field)
   }
 
+  /**
+   * Convenience method to get the count (number of elements) of a DataSet.
+   *
+   * @return A long integer that represents the number of elements in the set.
+   *
+   * @see org.apache.flink.api.java.Utils.CountHelper
+   */
+  @throws(classOf[Exception])
+  def count: Long = {
+    val id = new AbstractID().toString
+    javaSet.flatMap(new CountHelper[T](id)).output(new DiscardingOutputFormat[lang.Long])
+    val res = getExecutionEnvironment.execute()
+    res.getAccumulatorResult[Long](id)
+  }
+
+  /**
+   * Convenience method to get the elements of a DataSet as a List.
+   * As DataSet can contain a lot of data, this method should be used with caution.
+   *
+   * @return A List containing the elements of the DataSet
+   *
+   * @see org.apache.flink.api.java.Utils.CollectHelper
+   */
+  @throws(classOf[Exception])
+  def collect: List[T] = {
+    val id = new AbstractID().toString
+    javaSet.flatMap(new Utils.CollectHelper[T](id)).output(new DiscardingOutputFormat[T])
+    val res = getExecutionEnvironment.execute()
+    // the accumulator result is a java.util.List; convert it instead of
+    // casting to a Scala List, which would fail at runtime
+    res.getAccumulatorResult[java.util.List[T]](id).asScala.toList
+  }
+
   /**
    * Creates a new [[DataSet]] by merging the elements of this DataSet using an associative reduce
    * function.
diff --git a/flink-tests/src/test/java/org/apache/flink/test/convenience/CountCollectITCase.java b/flink-tests/src/test/java/org/apache/flink/test/convenience/CountCollectITCase.java
new file mode 100644
index 0000000000000000000000000000000000000000..a306c9a125369054fc7bfb9517fe8b315851bbc9
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/convenience/CountCollectITCase.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.convenience;
+
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.tuple.Tuple2;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+
+import static org.junit.Assert.*;
+
+import org.junit.Test;
+
+
+public class CountCollectITCase {
+
+	@Test
+	public void testSimple() throws Exception {
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		env.setDegreeOfParallelism(5);
+
+		Integer[] input = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+
+		DataSet<Integer> data = env.fromElements(input);
+
+		// count
+		long numEntries = data.count();
+		assertEquals(10, numEntries);
+
+		// collect
+		ArrayList<Integer> list = (ArrayList<Integer>) data.collect();
+		assertArrayEquals(input, list.toArray());
+	}
+
+	@Test
+	public void testAdvanced() throws Exception {
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		env.setDegreeOfParallelism(5);
+		env.getConfig().disableObjectReuse();
+
+		DataSet<Integer> data = env.fromElements(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
+		DataSet<Integer> data2 = env.fromElements(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
+
+		DataSet<Tuple2<Integer, Integer>> data3 = data.cross(data2);
+
+		// count
+		long numEntries = data3.count();
+		assertEquals(100, numEntries);
+
+		// collect
+		ArrayList<Tuple2<Integer, Integer>> list =
+				(ArrayList<Tuple2<Integer, Integer>>) data3.collect();
+		System.out.println(list);
+
+		// set expected entries in a hash map to true
+		HashMap<Tuple2<Integer, Integer>, Boolean> expected =
+				new HashMap<Tuple2<Integer, Integer>, Boolean>();
+		for (int i = 1; i <= 10; i++) {
+			for (int j = 1; j <= 10; j++) {
+				expected.put(new Tuple2<Integer, Integer>(i, j), true);
+			}
+		}
+
+		// check that every expected entry occurs exactly once
+		for (int i = 0; i < 100; i++) {
+			Tuple2<Integer, Integer> element = list.get(i);
+			assertEquals(Boolean.TRUE, expected.get(element));
+			expected.remove(element);
+		}
+	}
+}