Commit 3dc2fe1d authored by Max Michels

[java-api][scala-api] convenience methods count/collect to transfer a DataSet to the client

- this implements two convenience methods on DataSet for the Java and Scala APIs
- appropriate tests have been added

count(): returns the number of elements in a DataSet
collect(): returns a List<T> with the actual elements of a DataSet<T>

- both methods use accumulators to get the results back to the client
- both methods force an execution of the job to generate the results
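
For illustration, a hedged usage sketch (the environment setup and input values below are assumptions, not part of this commit; note that each call forces a separate execution):

ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
DataSet<Integer> data = env.fromElements(1, 2, 3);
long n = data.count();                   // forces an execution, returns 3
List<Integer> elements = data.collect(); // forces another execution, returns [1, 2, 3]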

This closes #210
Parent facf54ea
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.common.accumulators;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import org.apache.commons.lang3.SerializationUtils;
/**
 * This accumulator stores a collection of objects which are immediately serialized to cope with object reuse.
 * When the objects are requested again, they are deserialized.
 *
 * @param <T> The type of the accumulated objects
 */
public class ListAccumulator<T> implements Accumulator<T, ArrayList<T>> {
private static final long serialVersionUID = 1L;
private ArrayList<byte[]> localValue = new ArrayList<byte[]>();
@Override
public void add(T value) {
byte[] byteArray = SerializationUtils.serialize((Serializable) value);
localValue.add(byteArray);
}
@Override
public ArrayList<T> getLocalValue() {
ArrayList<T> arrList = new ArrayList<T>();
for (byte[] byteArr : localValue) {
T item = SerializationUtils.deserialize(byteArr);
arrList.add(item);
}
return arrList;
}
@Override
public void resetLocal() {
localValue.clear();
}
@Override
public void merge(Accumulator<T, ArrayList<T>> other) {
localValue.addAll(((ListAccumulator<T>) other).localValue);
}
@Override
public Accumulator<T, ArrayList<T>> clone() {
ListAccumulator<T> newInstance = new ListAccumulator<T>();
for (byte[] item : localValue) {
newInstance.localValue.add(item.clone());
}
return newInstance;
}
@Override
public void write(ObjectOutputStream out) throws IOException {
int numItems = localValue.size();
out.writeInt(numItems);
for (byte[] item : localValue) {
out.writeInt(item.length);
out.write(item);
}
}
@Override
public void read(ObjectInputStream in) throws IOException {
int numItems = in.readInt();
for (int i = 0; i < numItems; i++) {
int len = in.readInt();
byte[] obj = new byte[len];
// read(byte[]) may return fewer bytes than requested; readFully reads the complete record
in.readFully(obj);
localValue.add(obj);
}
}
}
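
As an aside, a minimal sketch (the array values are illustrative) of why add() serializes eagerly: with object reuse, the runtime may pass the same instance to add() repeatedly, so storing references would make all stored elements alias the latest value; snapshotting to bytes avoids that:

ListAccumulator<int[]> acc = new ListAccumulator<int[]>();
int[] reused = new int[]{1};
acc.add(reused);   // stores a serialized snapshot of {1}
reused[0] = 2;     // mutating the reused instance...
acc.add(reused);   // ...leaves the first snapshot intact
// acc.getLocalValue() deserializes fresh copies: [{1}, {2}]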
@@ -18,8 +18,11 @@
package org.apache.flink.api.java;
import java.util.List;
import org.apache.commons.lang3.Validate;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
@@ -36,11 +39,13 @@ import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint;
import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint;
import org.apache.flink.api.common.operators.base.PartitionOperatorBase.PartitionMethod;
import org.apache.flink.api.java.aggregation.Aggregations;
import org.apache.flink.api.java.functions.FirstReducer;
import org.apache.flink.api.java.functions.FormattingMapper;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.functions.SelectByMaxFunction;
import org.apache.flink.api.java.functions.SelectByMinFunction;
import org.apache.flink.api.java.io.CsvOutputFormat;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.api.java.io.PrintingOutputFormat;
import org.apache.flink.api.java.io.TextOutputFormat;
import org.apache.flink.api.java.io.TextOutputFormat.TextFormatter;
@@ -54,7 +59,6 @@ import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.operators.DistinctOperator;
import org.apache.flink.api.java.operators.FilterOperator;
import org.apache.flink.api.java.operators.ProjectOperator;
import org.apache.flink.api.java.functions.FirstReducer;
import org.apache.flink.api.java.operators.FlatMapOperator;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.IterativeDataSet;
@@ -78,6 +82,7 @@ import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.core.fs.FileSystem.WriteMode;
import org.apache.flink.core.fs.Path;
import org.apache.flink.util.AbstractID;
import com.google.common.base.Preconditions;
@@ -371,7 +376,45 @@ public abstract class DataSet<T> {
public AggregateOperator<T> min(int field) {
return aggregate(Aggregations.MIN, field);
}
/**
 * Convenience method to get the count (number of elements) of a DataSet.
 * Note that this method triggers an execution of the program.
 *
 * @return The number of elements in the DataSet
 *
 * @see org.apache.flink.api.java.Utils.CountHelper
 */
public long count() throws Exception {
final String id = new AbstractID().toString();
// attach a flatMap that counts the elements and reports the total through an accumulator
flatMap(new Utils.CountHelper<T>(id)).output(
new DiscardingOutputFormat<Long>());
// force an execution of the program to produce the accumulator result
JobExecutionResult res = getExecutionEnvironment().execute();
return res.<Long>getAccumulatorResult(id);
}
/**
 * Convenience method to get the elements of a DataSet as a List.
 * Because a DataSet can contain a lot of data, this method should be used with caution:
 * all elements are gathered at the client. Note that this method triggers an execution of the program.
 *
 * @return A List containing the elements of the DataSet
 *
 * @see org.apache.flink.api.java.Utils.CollectHelper
 */
public List<T> collect() throws Exception {
final String id = new AbstractID().toString();
// attach a flatMap that copies every element into a ListAccumulator
this.flatMap(new Utils.CollectHelper<T>(id)).output(
new DiscardingOutputFormat<T>());
// force an execution of the program to produce the accumulator result
JobExecutionResult res = this.getExecutionEnvironment().execute();
return (List<T>) res.getAccumulatorResult(id);
}
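
Design note: CountHelper and CollectHelper never emit records; the DiscardingOutputFormat sink is attached only because a Flink data flow must terminate in a sink to be executed, which makes the side-effect-only flatMap part of the job.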
/**
* Applies a Reduce transformation on a non-grouped {@link DataSet}.<br/>
* The transformation consecutively calls a {@link org.apache.flink.api.common.functions.RichReduceFunction}
......
@@ -23,6 +23,11 @@ import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.java.typeutils.GenericTypeInfo;
import java.util.List;
import org.apache.flink.api.common.accumulators.ListAccumulator;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;
public class Utils {
@@ -60,4 +65,51 @@ public class Utils {
}
}
}
public static class CountHelper<T> extends RichFlatMapFunction<T, Long> {
private static final long serialVersionUID = 1L;
private final String id;
private long counter;
public CountHelper(String id) {
this.id = id;
this.counter = 0L;
}
@Override
public void flatMap(T value, Collector<Long> out) throws Exception {
counter++;
}
@Override
public void close() throws Exception {
getRuntimeContext().getLongCounter(id).add(counter);
}
}
public static class CollectHelper<T> extends RichFlatMapFunction<T, T> {
private static final long serialVersionUID = 1L;
private final String id;
private final ListAccumulator<T> accumulator;
public CollectHelper(String id) {
this.id = id;
this.accumulator = new ListAccumulator<T>();
}
@Override
public void open(Configuration parameters) throws Exception {
getRuntimeContext().addAccumulator(id, accumulator);
}
@Override
public void flatMap(T value, Collector<T> out) throws Exception {
accumulator.add(value);
}
}
}
@@ -17,6 +17,7 @@
*/
package org.apache.flink.api.scala
import java.lang
import org.apache.commons.lang3.Validate
import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.common.aggregators.Aggregator
@@ -24,22 +25,24 @@ import org.apache.flink.api.common.functions._
import org.apache.flink.api.common.io.{FileOutputFormat, OutputFormat}
import org.apache.flink.api.common.operators.Order
import org.apache.flink.api.common.operators.base.PartitionOperatorBase.PartitionMethod
import org.apache.flink.api.java.Utils.CountHelper
import org.apache.flink.api.java.aggregation.Aggregations
import org.apache.flink.api.java.functions.{FirstReducer, KeySelector}
import org.apache.flink.api.java.io.{PrintingOutputFormat, TextOutputFormat}
import org.apache.flink.api.java.io.{DiscardingOutputFormat, PrintingOutputFormat, TextOutputFormat}
import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint
import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint
import org.apache.flink.api.java.operators.Keys.ExpressionKeys
import org.apache.flink.api.java.operators._
import org.apache.flink.api.java.{DataSet => JavaDataSet, SortPartitionOperator}
import org.apache.flink.api.java.{DataSet => JavaDataSet, SortPartitionOperator, Utils}
import org.apache.flink.api.scala.operators.{ScalaCsvOutputFormat, ScalaAggregateOperator}
import org.apache.flink.configuration.Configuration
import org.apache.flink.core.fs.{FileSystem, Path}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.util.Collector
import org.apache.flink.util.{AbstractID, Collector}
import scala.collection.JavaConverters._
import scala.reflect.ClassTag
/**
* The DataSet, the basic abstraction of Flink. This represents a collection of elements of a
* specific type `T`. The operations in this class can be used to create new DataSets and to combine
@@ -508,6 +511,37 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
aggregate(Aggregations.MIN, field)
}
/**
 * Convenience method to get the count (number of elements) of a DataSet.
 * Note that this method triggers an execution of the program.
 *
 * @return The number of elements in the DataSet
 *
 * @see org.apache.flink.api.java.Utils.CountHelper
 */
@throws(classOf[Exception])
def count: Long = {
val id = new AbstractID().toString
javaSet.flatMap(new CountHelper[T](id)).output(new DiscardingOutputFormat[lang.Long])
val res = getExecutionEnvironment.execute()
res.getAccumulatorResult[Long](id)
}
/**
 * Convenience method to get the elements of a DataSet as a List.
 * Because a DataSet can contain a lot of data, this method should be used with caution:
 * all elements are gathered at the client. Note that this method triggers an execution of the program.
 *
 * @return A List containing the elements of the DataSet
 *
 * @see org.apache.flink.api.java.Utils.CollectHelper
 */
@throws(classOf[Exception])
def collect: List[T] = {
val id = new AbstractID().toString
javaSet.flatMap(new Utils.CollectHelper[T](id)).output(new DiscardingOutputFormat[T])
val res = getExecutionEnvironment.execute()
// the accumulator delivers a java.util.ArrayList; convert it to a Scala List
// (a direct asInstanceOf[List[T]] would fail with a ClassCastException at runtime)
res.getAccumulatorResult[java.util.ArrayList[T]](id).asScala.toList
}
/**
* Creates a new [[DataSet]] by merging the elements of this DataSet using an associative reduce
* function.
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.convenience;
import static org.junit.Assert.*;
import java.util.HashMap;
import java.util.List;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.junit.Test;
public class CountCollectITCase {
@Test
public void testSimple() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setDegreeOfParallelism(5);
Integer[] input = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
DataSet<Integer> data = env.fromElements(input);
// count
long numEntries = data.count();
assertEquals(10, numEntries);
// collect
List<Integer> list = data.collect();
assertArrayEquals(input, list.toArray());
}
@Test
public void testAdvanced() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.setDegreeOfParallelism(5);
env.getConfig().disableObjectReuse();
DataSet<Integer> data = env.fromElements(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
DataSet<Integer> data2 = env.fromElements(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
DataSet<Tuple2<Integer, Integer>> data3 = data.cross(data2);
// count
long numEntries = data3.count();
assertEquals(100, numEntries);
// collect
List<Tuple2<Integer, Integer>> list = data3.collect();
// set expected entries in a hash map to true
HashMap<Tuple2<Integer, Integer>, Boolean> expected = new HashMap<Tuple2<Integer, Integer>, Boolean>();
for (int i = 1; i <= 10; i++) {
for (int j = 1; j <= 10; j++) {
expected.put(new Tuple2<Integer, Integer>(i, j), true);
}
}
// check if all entries are contained in the hash map
for (int i = 0; i < 100; i++) {
Tuple2<Integer, Integer> element = list.get(i);
assertTrue(expected.containsKey(element));
expected.remove(element);
}
// after removing each seen element, nothing may remain
assertTrue(expected.isEmpty());
}
}