Commit 358259d2 authored by Sachin Goel, committed by Stephan Ewen

[FLINK-2458] [FLINK-2449] [runtime] Access distributed cache entries from Iteration contexts & use of distributed cache from Collection Environments

This closes #970
Parent 0a7cc023
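For context, a minimal sketch (not part of this commit) of the user-facing behavior these fixes enable: a file registered in the distributed cache can now be read from a rich function even when the program runs on a collection environment or inside an iteration. The file path, cache name, and class name are illustrative.

import java.io.File;
import java.util.List;

import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.configuration.Configuration;

public class CachedFileSketch {
	public static void main(String[] args) throws Exception {
		// Collection environments now support the distributed cache as well.
		ExecutionEnvironment env = ExecutionEnvironment.createCollectionsEnvironment();
		env.registerCachedFile("/tmp/dictionary.txt", "dictionary"); // illustrative file

		DataSet<String> result = env.fromElements("a", "b")
			.map(new RichMapFunction<String, String>() {
				private File dict;

				@Override
				public void open(Configuration parameters) throws Exception {
					// Served from the Future<Path> map this commit threads through the contexts.
					dict = getRuntimeContext().getDistributedCache().getFile("dictionary");
				}

				@Override
				public String map(String value) {
					return value + ":" + dict.getName();
				}
			});

		List<String> out = result.collect(); // triggers execution
		System.out.println(out);
	}
}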
......@@ -57,17 +57,6 @@ public abstract class AbstractRuntimeUDFContext implements RuntimeContext {
private final DistributedCache distributedCache;
public AbstractRuntimeUDFContext(String name,
int numParallelSubtasks, int subtaskIndex,
ClassLoader userCodeClassLoader,
ExecutionConfig executionConfig,
Map<String, Accumulator<?,?>> accumulators)
{
this(name, numParallelSubtasks, subtaskIndex, userCodeClassLoader, executionConfig,
accumulators, Collections.<String, Future<Path>>emptyMap());
}
public AbstractRuntimeUDFContext(String name,
int numParallelSubtasks, int subtaskIndex,
ClassLoader userCodeClassLoader,
......@@ -79,7 +68,7 @@ public abstract class AbstractRuntimeUDFContext implements RuntimeContext {
this.subtaskIndex = subtaskIndex;
this.userCodeClassLoader = userCodeClassLoader;
this.executionConfig = executionConfig;
this.distributedCache = new DistributedCache(cpTasks);
this.distributedCache = new DistributedCache(Preconditions.checkNotNull(cpTasks));
this.accumulators = Preconditions.checkNotNull(accumulators);
}
......
......@@ -37,18 +37,11 @@ public class RuntimeUDFContext extends AbstractRuntimeUDFContext {
private final HashMap<String, Object> initializedBroadcastVars = new HashMap<String, Object>();
private final HashMap<String, List<?>> uninitializedBroadcastVars = new HashMap<String, List<?>>();
public RuntimeUDFContext(String name, int numParallelSubtasks, int subtaskIndex, ClassLoader userCodeClassLoader,
ExecutionConfig executionConfig, Map<String, Accumulator<?,?>> accumulators) {
super(name, numParallelSubtasks, subtaskIndex, userCodeClassLoader, executionConfig, accumulators);
}
public RuntimeUDFContext(String name, int numParallelSubtasks, int subtaskIndex, ClassLoader userCodeClassLoader,
ExecutionConfig executionConfig, Map<String, Future<Path>> cpTasks, Map<String, Accumulator<?,?>> accumulators) {
super(name, numParallelSubtasks, subtaskIndex, userCodeClassLoader, executionConfig, accumulators, cpTasks);
}
@Override
@SuppressWarnings("unchecked")
......
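With the accumulators-only convenience constructors removed (here and in DistributedRuntimeUDFContext further below), every call site must now pass the cache-entry map explicitly. A minimal sketch of the new call shape, mirroring the test updates later in this diff; the task name and parameters are illustrative:

import java.util.Collections;
import java.util.HashMap;
import java.util.concurrent.Future;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.core.fs.Path;

// No cached files at this call site, so an empty map is passed explicitly.
RuntimeUDFContext ctx = new RuntimeUDFContext("example task", 1, 0,
		RuntimeUDFContext.class.getClassLoader(), new ExecutionConfig(),
		Collections.<String, Future<Path>>emptyMap(),
		new HashMap<String, Accumulator<?, ?>>());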
......@@ -27,6 +27,10 @@ import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.InvalidProgramException;
......@@ -37,6 +41,7 @@ import org.apache.flink.api.common.accumulators.AccumulatorHelper;
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.aggregators.AggregatorWithName;
import org.apache.flink.api.common.aggregators.ConvergenceCriterion;
import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.api.common.functions.IterationRuntimeContext;
import org.apache.flink.api.common.functions.RichFunction;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
......@@ -51,6 +56,8 @@ import org.apache.flink.api.common.operators.util.TypeComparable;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.core.fs.Path;
import org.apache.flink.core.fs.local.LocalFileSystem;
import org.apache.flink.types.Value;
import org.apache.flink.util.Visitor;
......@@ -64,6 +71,8 @@ public class CollectionExecutor {
private final Map<Operator<?>, List<?>> intermediateResults;
private final Map<String, Accumulator<?, ?>> accumulators;
private final Map<String, Future<Path>> cachedFiles;
private final Map<String, Value> previousAggregates;
......@@ -84,7 +93,7 @@ public class CollectionExecutor {
this.accumulators = new HashMap<String, Accumulator<?,?>>();
this.previousAggregates = new HashMap<String, Value>();
this.aggregators = new HashMap<String, Aggregator<?>>();
this.cachedFiles = new HashMap<String, Future<Path>>();
this.classLoader = getClass().getClassLoader();
}
......@@ -94,7 +103,7 @@ public class CollectionExecutor {
public JobExecutionResult execute(Plan program) throws Exception {
long startTime = System.currentTimeMillis();
initCache(program.getCachedFiles());
Collection<? extends GenericDataSinkBase<?>> sinks = program.getDataSinks();
for (Operator<?> sink : sinks) {
execute(sink);
......@@ -104,7 +113,14 @@ public class CollectionExecutor {
Map<String, Object> accumulatorResults = AccumulatorHelper.toResultMap(accumulators);
return new JobExecutionResult(null, endTime - startTime, accumulatorResults);
}
private void initCache(Set<Map.Entry<String, DistributedCache.DistributedCacheEntry>> files) {
// Register every cached file as an already-completed future; the collection
// executor runs locally, so there is nothing to copy asynchronously.
for (Map.Entry<String, DistributedCache.DistributedCacheEntry> file : files) {
Future<Path> doNothing = new CompletedFuture(new Path(file.getValue().filePath));
cachedFiles.put(file.getKey(), doNothing);
}
}
private List<?> execute(Operator<?> operator) throws Exception {
return execute(operator, 0);
}
......@@ -165,8 +181,8 @@ public class CollectionExecutor {
// build the runtime context and compute broadcast variables, if necessary
RuntimeUDFContext ctx;
if (RichOutputFormat.class.isAssignableFrom(typedSink.getUserCodeWrapper().getUserCodeClass())) {
ctx = superStep == 0 ? new RuntimeUDFContext(typedSink.getName(), 1, 0, getClass().getClassLoader(), executionConfig, accumulators) :
new IterationRuntimeUDFContext(typedSink.getName(), 1, 0, classLoader, executionConfig, accumulators);
ctx = superStep == 0 ? new RuntimeUDFContext(typedSink.getName(), 1, 0, getClass().getClassLoader(), executionConfig, cachedFiles, accumulators) :
new IterationRuntimeUDFContext(typedSink.getName(), 1, 0, classLoader, executionConfig, cachedFiles, accumulators);
} else {
ctx = null;
}
......@@ -181,8 +197,8 @@ public class CollectionExecutor {
// build the runtime context and compute broadcast variables, if necessary
RuntimeUDFContext ctx;
if (RichInputFormat.class.isAssignableFrom(typedSource.getUserCodeWrapper().getUserCodeClass())) {
ctx = superStep == 0 ? new RuntimeUDFContext(source.getName(), 1, 0, getClass().getClassLoader(), executionConfig, accumulators) :
new IterationRuntimeUDFContext(source.getName(), 1, 0, classLoader, executionConfig, accumulators);
ctx = superStep == 0 ? new RuntimeUDFContext(source.getName(), 1, 0, getClass().getClassLoader(), executionConfig, cachedFiles, accumulators) :
new IterationRuntimeUDFContext(source.getName(), 1, 0, classLoader, executionConfig, cachedFiles, accumulators);
} else {
ctx = null;
}
......@@ -204,8 +220,10 @@ public class CollectionExecutor {
// build the runtime context and compute broadcast variables, if necessary
RuntimeUDFContext ctx;
if (RichFunction.class.isAssignableFrom(typedOp.getUserCodeWrapper().getUserCodeClass())) {
ctx = superStep == 0 ? new RuntimeUDFContext(operator.getName(), 1, 0, getClass().getClassLoader(), executionConfig, accumulators) :
new IterationRuntimeUDFContext(operator.getName(), 1, 0, classLoader, executionConfig, accumulators);
ctx = superStep == 0 ? new RuntimeUDFContext(operator.getName(), 1, 0, getClass()
.getClassLoader(), executionConfig, cachedFiles, accumulators) :
new IterationRuntimeUDFContext(operator.getName(), 1, 0, classLoader,
executionConfig, cachedFiles, accumulators);
for (Map.Entry<String, Operator<?>> bcInputs : operator.getBroadcastInputs().entrySet()) {
List<?> bcData = execute(bcInputs.getValue());
......@@ -243,8 +261,10 @@ public class CollectionExecutor {
// build the runtime context and compute broadcast variables, if necessary
RuntimeUDFContext ctx;
if (RichFunction.class.isAssignableFrom(typedOp.getUserCodeWrapper().getUserCodeClass())) {
ctx = superStep == 0 ? new RuntimeUDFContext(operator.getName(), 1, 0, classLoader, executionConfig, accumulators) :
new IterationRuntimeUDFContext(operator.getName(), 1, 0, classLoader, executionConfig, accumulators);
ctx = superStep == 0 ? new RuntimeUDFContext(operator.getName(), 1, 0, classLoader,
executionConfig, cachedFiles, accumulators) :
new IterationRuntimeUDFContext(operator.getName(), 1, 0, classLoader,
executionConfig, cachedFiles, accumulators);
for (Map.Entry<String, Operator<?>> bcInputs : operator.getBroadcastInputs().entrySet()) {
List<?> bcData = execute(bcInputs.getValue());
......@@ -500,8 +520,9 @@ public class CollectionExecutor {
private class IterationRuntimeUDFContext extends RuntimeUDFContext implements IterationRuntimeContext {
public IterationRuntimeUDFContext(String name, int numParallelSubtasks, int subtaskIndex, ClassLoader classloader,
ExecutionConfig executionConfig, Map<String, Accumulator<?,?>> accumulators) {
super(name, numParallelSubtasks, subtaskIndex, classloader, executionConfig, accumulators);
ExecutionConfig executionConfig, Map<String, Future<Path>> cpTasks, Map<String,
Accumulator<?,?>> accumulators) {
super(name, numParallelSubtasks, subtaskIndex, classloader, executionConfig, cpTasks, accumulators);
}
@Override
......@@ -521,4 +542,43 @@ public class CollectionExecutor {
return (T) previousAggregates.get(name);
}
}
private static final class CompletedFuture implements Future<Path> {
private final Path result;
public CompletedFuture(Path entry) {
try {
LocalFileSystem fs = (LocalFileSystem) entry.getFileSystem();
result = entry.isAbsolute() ? new Path(entry.toUri().getPath()) : new Path(fs.getWorkingDirectory(), entry);
} catch (Exception e) {
// Chain the cause so the original failure is not lost.
throw new RuntimeException("DistributedCache supports only local files for Collection Environments", e);
}
}
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
return false;
}
@Override
public boolean isCancelled() {
return false;
}
@Override
public boolean isDone() {
return true;
}
@Override
public Path get() throws InterruptedException, ExecutionException {
return result;
}
@Override
public Path get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
return get();
}
}
}
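CompletedFuture above is a hand-rolled, already-finished Future: the collection executor resolves cache entries eagerly, so get() can return immediately. This code predates Java 8; on Java 8+ an equivalent could presumably be built with CompletableFuture (a sketch, not what this commit uses; the path is illustrative):

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;

import org.apache.flink.core.fs.Path;

// An already-completed future wrapping a local path.
Future<Path> cached = CompletableFuture.completedFuture(new Path("/tmp/dictionary.txt"));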
......@@ -24,10 +24,12 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.Future;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.core.fs.Path;
import org.junit.Test;
......@@ -36,7 +38,7 @@ public class RuntimeUDFContextTest {
@Test
public void testBroadcastVariableNotFound() {
try {
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Accumulator<?, ?>>());
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Future<Path>>(),new HashMap<String, Accumulator<?, ?>>());
try {
ctx.getBroadcastVariable("some name");
......@@ -66,7 +68,7 @@ public class RuntimeUDFContextTest {
@Test
public void testBroadcastVariableSimple() {
try {
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Accumulator<?, ?>>());
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>());
ctx.setBroadcastVariable("name1", Arrays.asList(1, 2, 3, 4));
ctx.setBroadcastVariable("name2", Arrays.asList(1.0, 2.0, 3.0, 4.0));
......@@ -100,7 +102,7 @@ public class RuntimeUDFContextTest {
@Test
public void testBroadcastVariableWithInitializer() {
try {
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Accumulator<?, ?>>());
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>());
ctx.setBroadcastVariable("name", Arrays.asList(1, 2, 3, 4));
......@@ -125,7 +127,7 @@ public class RuntimeUDFContextTest {
@Test
public void testResetBroadcastVariableWithInitializer() {
try {
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Accumulator<?, ?>>());
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>());
ctx.setBroadcastVariable("name", Arrays.asList(1, 2, 3, 4));
......@@ -148,7 +150,7 @@ public class RuntimeUDFContextTest {
@Test
public void testBroadcastVariableWithInitializerAndMismatch() {
try {
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Accumulator<?, ?>>());
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>());
ctx.setBroadcastVariable("name", Arrays.asList(1, 2, 3, 4));
......
......@@ -20,10 +20,12 @@
package org.apache.flink.api.common.io;
import java.util.HashMap;
import java.util.concurrent.Future;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.core.fs.Path;
import org.apache.flink.types.Value;
import org.junit.Assert;
import org.junit.Test;
......@@ -36,8 +38,7 @@ public class RichInputFormatTest {
@Test
public void testCheckRuntimeContextAccess() {
final SerializedInputFormat<Value> inputFormat = new SerializedInputFormat<Value>();
inputFormat.setRuntimeContext(new RuntimeUDFContext("test name", 3, 1,
getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Accumulator<?, ?>>()));
inputFormat.setRuntimeContext(new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>()));
Assert.assertEquals(inputFormat.getRuntimeContext().getIndexOfThisSubtask(), 1);
Assert.assertEquals(inputFormat.getRuntimeContext().getNumberOfParallelSubtasks(),3);
......
......@@ -20,10 +20,12 @@
package org.apache.flink.api.common.io;
import java.util.HashMap;
import java.util.concurrent.Future;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.core.fs.Path;
import org.apache.flink.types.Value;
import org.junit.Assert;
import org.junit.Test;
......@@ -36,8 +38,7 @@ public class RichOutputFormatTest {
@Test
public void testCheckRuntimeContextAccess() {
final SerializedOutputFormat<Value> inputFormat = new SerializedOutputFormat<Value>();
inputFormat.setRuntimeContext(new RuntimeUDFContext("test name", 3, 1,
getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Accumulator<?, ?>>()));
inputFormat.setRuntimeContext(new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>()));
Assert.assertEquals(inputFormat.getRuntimeContext().getIndexOfThisSubtask(), 1);
Assert.assertEquals(inputFormat.getRuntimeContext().getNumberOfParallelSubtasks(),3);
......
......@@ -26,10 +26,12 @@ import org.apache.flink.api.common.operators.util.TestNonRichOutputFormat;
import org.apache.flink.api.common.operators.util.TestNonRichInputFormat;
import org.apache.flink.api.common.operators.util.TestRichOutputFormat;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.core.fs.Path;
import org.apache.flink.types.Nothing;
import org.junit.Test;
import java.util.HashMap;
import java.util.concurrent.Future;
import static java.util.Arrays.asList;
import static org.junit.Assert.assertEquals;
......@@ -87,15 +89,16 @@ public class GenericDataSinkBaseTest implements java.io.Serializable {
ExecutionConfig executionConfig = new ExecutionConfig();
final HashMap<String, Accumulator<?, ?>> accumulatorMap = new HashMap<String, Accumulator<?, ?>>();
final HashMap<String, Future<Path>> cpTasks = new HashMap<>();
executionConfig.disableObjectReuse();
in.reset();
sink.executeOnCollections(asList(TestIOData.NAMES), new RuntimeUDFContext("test_sink", 1, 0, null, executionConfig, accumulatorMap), executionConfig);
sink.executeOnCollections(asList(TestIOData.NAMES), new RuntimeUDFContext("test_sink", 1, 0, null, executionConfig, cpTasks, accumulatorMap), executionConfig);
assertEquals(out.output, asList(TestIOData.RICH_NAMES));
executionConfig.enableObjectReuse();
out.clear();
in.reset();
sink.executeOnCollections(asList(TestIOData.NAMES), new RuntimeUDFContext("test_sink", 1, 0, null, executionConfig, accumulatorMap), executionConfig);
sink.executeOnCollections(asList(TestIOData.NAMES), new RuntimeUDFContext("test_sink", 1, 0, null, executionConfig, cpTasks, accumulatorMap), executionConfig);
assertEquals(out.output, asList(TestIOData.RICH_NAMES));
} catch(Exception e){
e.printStackTrace();
......
......@@ -25,10 +25,12 @@ import org.apache.flink.api.common.operators.util.TestIOData;
import org.apache.flink.api.common.operators.util.TestNonRichInputFormat;
import org.apache.flink.api.common.operators.util.TestRichInputFormat;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.core.fs.Path;
import org.junit.Test;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.Future;
import static java.util.Arrays.asList;
import static org.junit.Assert.assertEquals;
......@@ -73,13 +75,14 @@ public class GenericDataSourceBaseTest implements java.io.Serializable {
in, new OperatorInformation<String>(BasicTypeInfo.STRING_TYPE_INFO), "testSource");
final HashMap<String, Accumulator<?, ?>> accumulatorMap = new HashMap<String, Accumulator<?, ?>>();
final HashMap<String, Future<Path>> cpTasks = new HashMap<>();
ExecutionConfig executionConfig = new ExecutionConfig();
executionConfig.disableObjectReuse();
List<String> resultMutableSafe = source.executeOnCollections(new RuntimeUDFContext("test_source", 1, 0, null, executionConfig, accumulatorMap), executionConfig);
List<String> resultMutableSafe = source.executeOnCollections(new RuntimeUDFContext("test_source", 1, 0, null, executionConfig, cpTasks, accumulatorMap), executionConfig);
in.reset();
executionConfig.enableObjectReuse();
List<String> resultRegular = source.executeOnCollections(new RuntimeUDFContext("test_source", 1, 0, null, executionConfig, accumulatorMap), executionConfig);
List<String> resultRegular = source.executeOnCollections(new RuntimeUDFContext("test_source", 1, 0, null, executionConfig, cpTasks, accumulatorMap), executionConfig);
assertEquals(asList(TestIOData.RICH_NAMES), resultMutableSafe);
assertEquals(asList(TestIOData.RICH_NAMES), resultRegular);
......
......@@ -27,6 +27,7 @@ import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;
......@@ -36,6 +37,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.Future;
@SuppressWarnings("serial")
public class FlatMapOperatorCollectionTest implements Serializable {
......@@ -74,7 +76,7 @@ public class FlatMapOperatorCollectionTest implements Serializable {
}
// run on collections
final List<String> result = getTestFlatMapOperator(udf)
.executeOnCollections(input, new RuntimeUDFContext("Test UDF", 4, 0, null, executionConfig, new HashMap<String, Accumulator<?, ?>>()), executionConfig);
.executeOnCollections(input, new RuntimeUDFContext("Test UDF", 4, 0, null, executionConfig, new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>()), executionConfig);
Assert.assertEquals(input.size(), result.size());
Assert.assertEquals(input, result);
......
......@@ -28,6 +28,7 @@ import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.api.common.operators.BinaryOperatorInformation;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.apache.flink.util.Collector;
import org.junit.Test;
......@@ -36,6 +37,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
@SuppressWarnings("serial")
......@@ -117,11 +119,13 @@ public class JoinOperatorBaseTest implements Serializable {
try {
final HashMap<String, Accumulator<?, ?>> accumulatorMap = new HashMap<String, Accumulator<?, ?>>();
final HashMap<String, Future<Path>> cpTasks = new HashMap<>();
ExecutionConfig executionConfig = new ExecutionConfig();
executionConfig.disableObjectReuse();
List<Integer> resultSafe = base.executeOnCollections(inputData1, inputData2, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, accumulatorMap), executionConfig);
List<Integer> resultSafe = base.executeOnCollections(inputData1, inputData2, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, cpTasks, accumulatorMap), executionConfig);
executionConfig.enableObjectReuse();
List<Integer> resultRegular = base.executeOnCollections(inputData1, inputData2, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, accumulatorMap), executionConfig);
List<Integer> resultRegular = base.executeOnCollections(inputData1, inputData2, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, cpTasks, accumulatorMap), executionConfig);
assertEquals(expected, resultSafe);
assertEquals(expected, resultRegular);
......
......@@ -24,6 +24,7 @@ import static java.util.Arrays.asList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.flink.api.common.ExecutionConfig;
......@@ -36,6 +37,7 @@ import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.base.MapOperatorBase;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.junit.Test;
@SuppressWarnings("serial")
......@@ -105,11 +107,12 @@ public class MapOperatorTest implements java.io.Serializable {
List<String> input = new ArrayList<String>(asList("1", "2", "3", "4", "5", "6"));
final HashMap<String, Accumulator<?, ?>> accumulatorMap = new HashMap<String, Accumulator<?, ?>>();
final HashMap<String, Future<Path>> cpTasks = new HashMap<>();
ExecutionConfig executionConfig = new ExecutionConfig();
executionConfig.disableObjectReuse();
List<Integer> resultMutableSafe = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, accumulatorMap), executionConfig);
List<Integer> resultMutableSafe = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, cpTasks, accumulatorMap), executionConfig);
executionConfig.enableObjectReuse();
List<Integer> resultRegular = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, accumulatorMap), executionConfig);
List<Integer> resultRegular = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, cpTasks, accumulatorMap), executionConfig);
assertEquals(asList(1, 2, 3, 4, 5, 6), resultMutableSafe);
assertEquals(asList(1, 2, 3, 4, 5, 6), resultRegular);
......
......@@ -24,10 +24,12 @@ import static java.util.Arrays.asList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.core.fs.Path;
import org.apache.flink.util.Collector;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
......@@ -80,9 +82,9 @@ public class PartitionMapOperatorTest implements java.io.Serializable {
ExecutionConfig executionConfig = new ExecutionConfig();
executionConfig.disableObjectReuse();
List<Integer> resultMutableSafe = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, new HashMap<String, Accumulator<?, ?>>()), executionConfig);
List<Integer> resultMutableSafe = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>()), executionConfig);
executionConfig.enableObjectReuse();
List<Integer> resultRegular = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, new HashMap<String, Accumulator<?, ?>>()), executionConfig);
List<Integer> resultRegular = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>()), executionConfig);
assertEquals(asList(1, 2, 3, 4, 5, 6), resultMutableSafe);
assertEquals(asList(1, 2, 3, 4, 5, 6), resultRegular);
......
......@@ -29,6 +29,7 @@ import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.builder.Tuple2Builder;
import org.apache.flink.api.java.typeutils.TypeInfoParser;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;
......@@ -40,6 +41,7 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Future;
@SuppressWarnings("serial")
public class CoGroupOperatorCollectionTest implements Serializable {
......@@ -71,7 +73,8 @@ public class CoGroupOperatorCollectionTest implements Serializable {
ExecutionConfig executionConfig = new ExecutionConfig();
final HashMap<String, Accumulator<?, ?>> accumulators = new HashMap<String, Accumulator<?, ?>>();
final RuntimeContext ctx = new RuntimeUDFContext("Test UDF", 4, 0, null, executionConfig, accumulators);
final HashMap<String, Future<Path>> cpTasks = new HashMap<>();
final RuntimeContext ctx = new RuntimeUDFContext("Test UDF", 4, 0, null, executionConfig, cpTasks, accumulators);
{
SumCoGroup udf1 = new SumCoGroup();
......
......@@ -28,6 +28,7 @@ import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TypeInfoParser;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.apache.flink.util.Collector;
import org.junit.Test;
......@@ -37,6 +38,7 @@ import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import static java.util.Arrays.asList;
......@@ -163,9 +165,9 @@ public class GroupReduceOperatorTest implements java.io.Serializable {
ExecutionConfig executionConfig = new ExecutionConfig();
executionConfig.disableObjectReuse();
List<Tuple2<String, Integer>> resultMutableSafe = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, new HashMap<String, Accumulator<?, ?>>()), executionConfig);
List<Tuple2<String, Integer>> resultMutableSafe = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>()), executionConfig);
executionConfig.enableObjectReuse();
List<Tuple2<String, Integer>> resultRegular = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, new HashMap<String, Accumulator<?, ?>>()), executionConfig);
List<Tuple2<String, Integer>> resultRegular = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>()), executionConfig);
Set<Tuple2<String, Integer>> resultSetMutableSafe = new HashSet<Tuple2<String, Integer>>(resultMutableSafe);
......
......@@ -28,6 +28,7 @@ import org.apache.flink.api.common.operators.BinaryOperatorInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.core.fs.Path;
import org.apache.flink.util.Collector;
import org.junit.Test;
......@@ -38,6 +39,7 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Future;
@SuppressWarnings({ "unchecked", "serial" })
public class JoinOperatorBaseTest implements Serializable {
......@@ -105,9 +107,9 @@ public class JoinOperatorBaseTest implements Serializable {
try {
ExecutionConfig executionConfig = new ExecutionConfig();
executionConfig.disableObjectReuse();
List<Tuple2<Double, String>> resultSafe = base.executeOnCollections(inputData1, inputData2, new RuntimeUDFContext("op", 1, 0, null, executionConfig, new HashMap<String, Accumulator<?, ?>>()), executionConfig);
List<Tuple2<Double, String>> resultSafe = base.executeOnCollections(inputData1, inputData2, new RuntimeUDFContext("op", 1, 0, null, executionConfig, new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>()), executionConfig);
executionConfig.enableObjectReuse();
List<Tuple2<Double, String>> resultRegular = base.executeOnCollections(inputData1, inputData2, new RuntimeUDFContext("op", 1, 0, null, executionConfig, new HashMap<String, Accumulator<?, ?>>()), executionConfig);
List<Tuple2<Double, String>> resultRegular = base.executeOnCollections(inputData1, inputData2, new RuntimeUDFContext("op", 1, 0, null, executionConfig, new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>()), executionConfig);
assertEquals(expected, new HashSet<Tuple2<Double, String>>(resultSafe));
assertEquals(expected, new HashSet<Tuple2<Double, String>>(resultRegular));
......
......@@ -28,6 +28,7 @@ import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TypeInfoParser;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.junit.Test;
import java.util.ArrayList;
......@@ -35,6 +36,7 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import static java.util.Arrays.asList;
......@@ -140,9 +142,9 @@ public class ReduceOperatorTest implements java.io.Serializable {
ExecutionConfig executionConfig = new ExecutionConfig();
executionConfig.disableObjectReuse();
List<Tuple2<String, Integer>> resultMutableSafe = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, new HashMap<String, Accumulator<?, ?>>()), executionConfig);
List<Tuple2<String, Integer>> resultMutableSafe = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>()), executionConfig);
executionConfig.enableObjectReuse();
List<Tuple2<String, Integer>> resultRegular = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, new HashMap<String, Accumulator<?, ?>>()), executionConfig);
List<Tuple2<String, Integer>> resultRegular = op.executeOnCollections(input, new RuntimeUDFContext(taskName, 1, 0, null, executionConfig, new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>()), executionConfig);
Set<Tuple2<String, Integer>> resultSetMutableSafe = new HashSet<Tuple2<String, Integer>>(resultMutableSafe);
Set<Tuple2<String, Integer>> resultSetRegular = new HashSet<Tuple2<String, Integer>>(resultRegular);
......
......@@ -20,6 +20,7 @@ package org.apache.flink.runtime.iterative.task;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.core.fs.Path;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.flink.api.common.aggregators.Aggregator;
......@@ -55,6 +56,7 @@ import org.apache.flink.util.MutableObjectIterator;
import java.io.IOException;
import java.io.Serializable;
import java.util.Map;
import java.util.concurrent.Future;
/**
* The abstract base class for all tasks able to participate in an iteration.
......@@ -169,7 +171,8 @@ public abstract class AbstractIterativePactTask<S extends Function, OT> extends
public DistributedRuntimeUDFContext createRuntimeContext(String taskName) {
Environment env = getEnvironment();
return new IterativeRuntimeUdfContext(taskName, env.getNumberOfSubtasks(),
env.getIndexInSubtaskGroup(), getUserCodeClassLoader(), getExecutionConfig(), this.accumulatorMap);
env.getIndexInSubtaskGroup(), getUserCodeClassLoader(), getExecutionConfig(),
env.getDistributedCacheEntries(), this.accumulatorMap);
}
// --------------------------------------------------------------------------------------------
......@@ -359,9 +362,9 @@ public abstract class AbstractIterativePactTask<S extends Function, OT> extends
private class IterativeRuntimeUdfContext extends DistributedRuntimeUDFContext implements IterationRuntimeContext {
public IterativeRuntimeUdfContext(String name, int numParallelSubtasks, int subtaskIndex, ClassLoader userCodeClassLoader,
ExecutionConfig executionConfig,
ExecutionConfig executionConfig, Map<String, Future<Path>> cpTasks,
Map<String, Accumulator<?,?>> accumulatorMap) {
super(name, numParallelSubtasks, subtaskIndex, userCodeClassLoader, executionConfig, accumulatorMap);
super(name, numParallelSubtasks, subtaskIndex, userCodeClassLoader, executionConfig, cpTasks, accumulatorMap);
}
@Override
......
......@@ -41,12 +41,6 @@ public class DistributedRuntimeUDFContext extends AbstractRuntimeUDFContext {
private final HashMap<String, BroadcastVariableMaterialization<?, ?>> broadcastVars = new HashMap<String, BroadcastVariableMaterialization<?, ?>>();
public DistributedRuntimeUDFContext(String name, int numParallelSubtasks, int subtaskIndex, ClassLoader userCodeClassLoader,
ExecutionConfig executionConfig, Map<String, Accumulator<?,?>> accumulators) {
super(name, numParallelSubtasks, subtaskIndex, userCodeClassLoader, executionConfig, accumulators);
}
public DistributedRuntimeUDFContext(String name, int numParallelSubtasks, int subtaskIndex, ClassLoader userCodeClassLoader,
ExecutionConfig executionConfig, Map<String, Future<Path>> cpTasks, Map<String, Accumulator<?,?>> accumulators) {
super(name, numParallelSubtasks, subtaskIndex, userCodeClassLoader, executionConfig, accumulators, cpTasks);
......
......@@ -21,6 +21,7 @@ package org.apache.flink.tez.runtime;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.Function;
import org.apache.flink.core.fs.Path;
import org.apache.flink.runtime.operators.PactDriver;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.tez.util.EncodingUtils;
......@@ -40,6 +41,7 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Future;
public class RegularProcessor<S extends Function, OT> extends AbstractLogicalIOProcessor {
......@@ -70,6 +72,7 @@ public class RegularProcessor<S extends Function, OT> extends AbstractLogicalIOP
getContext().getTaskIndex(),
getClass().getClassLoader(),
new ExecutionConfig(),
new HashMap<String, Future<Path>>(),
new HashMap<String, Accumulator<?, ?>>());
this.task = new TezTask<S, OT>(taskConfig, runtimeUdfContext, this.getContext().getTotalMemoryAvailableToTask());
......
......@@ -18,8 +18,14 @@
package org.apache.flink.test.iterative.aggregators;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.util.Random;
import org.apache.flink.api.common.functions.RichFilterFunction;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.test.util.MultipleProgramsTestBase;
import org.junit.After;
import org.junit.Assert;
......@@ -44,6 +50,8 @@ import org.apache.flink.api.java.operators.IterativeDataSet;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import static org.junit.Assert.assertEquals;
/**
* Test the functionality of aggregators in bulk and delta iterative cases.
*/
......@@ -54,6 +62,10 @@ public class AggregatorsITCase extends MultipleProgramsTestBase {
private static final int parallelism = 2;
private static final String NEGATIVE_ELEMENTS_AGGR = "count.negative.elements";
private static String testString = "Et tu, Brute?";
private static String testName = "testing_caesar";
private static String testPath;
public AggregatorsITCase(TestExecutionMode mode){
super(mode);
}
......@@ -66,7 +78,9 @@ public class AggregatorsITCase extends MultipleProgramsTestBase {
@Before
public void before() throws Exception{
resultPath = tempFolder.newFile().toURI().toString();
File tempFile = tempFolder.newFile();
testPath = tempFile.toString();
resultPath = tempFile.toURI().toString();
}
@After
......@@ -74,6 +88,35 @@ public class AggregatorsITCase extends MultipleProgramsTestBase {
compareResultsByLinesInMemory(expected, resultPath);
}
@Test
public void testDistributedCacheWithIterations() throws Exception {
File tempFile = new File(testPath);
FileWriter writer = new FileWriter(tempFile);
writer.write(testString);
writer.close();
final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
env.registerCachedFile(resultPath, testName);
IterativeDataSet<Long> solution = env.fromElements(1L).iterate(2);
solution.closeWith(env.generateSequence(1, 2).filter(new RichFilterFunction<Long>() {
@Override
public void open(Configuration parameters) throws Exception {
File file = getRuntimeContext().getDistributedCache().getFile(testName);
BufferedReader reader = new BufferedReader(new FileReader(file));
String output = reader.readLine();
reader.close();
assertEquals(testString, output);
}
@Override
public boolean filter(Long value) throws Exception {
return false;
}
}).withBroadcastSet(solution, "SOLUTION")).output(new DiscardingOutputFormat<Long>());
env.execute();
expected = testString; // resultPath points at the same temp file, so the @After comparison passes trivially
}
@Test
public void testAggregatorWithoutParameterForIterate() throws Exception {
/*
......