提交 73493335 编写于 作者: S Stephan Ewen

[FLINK-1959] [runtime] Support accumulators in chained functions after a non-UDF operation

上级 cf4f22ea
...@@ -40,7 +40,8 @@ public class AccumulatorHelper { ...@@ -40,7 +40,8 @@ public class AccumulatorHelper {
if (ownAccumulator == null) { if (ownAccumulator == null) {
// Take over counter from chained task // Take over counter from chained task
target.put(otherEntry.getKey(), otherEntry.getValue()); target.put(otherEntry.getKey(), otherEntry.getValue());
} else { }
else {
// Both should have the same type // Both should have the same type
AccumulatorHelper.compareAccumulatorTypes(otherEntry.getKey(), AccumulatorHelper.compareAccumulatorTypes(otherEntry.getKey(),
ownAccumulator.getClass(), otherEntry.getValue().getClass()); ownAccumulator.getClass(), otherEntry.getValue().getClass());
...@@ -122,13 +123,14 @@ public class AccumulatorHelper { ...@@ -122,13 +123,14 @@ public class AccumulatorHelper {
return builder.toString(); return builder.toString();
} }
public static void resetAndClearAccumulators( public static void resetAndClearAccumulators(Map<String, Accumulator<?, ?>> accumulators) {
Map<String, Accumulator<?, ?>> accumulators) { if (accumulators != null) {
for (Map.Entry<String, Accumulator<?, ?>> entry : accumulators.entrySet()) { for (Map.Entry<String, Accumulator<?, ?>> entry : accumulators.entrySet()) {
entry.getValue().resetLocal(); entry.getValue().resetLocal();
} }
accumulators.clear(); accumulators.clear();
} }
}
public static Map<String, Accumulator<?, ?>> copy(final Map<String, Accumulator<?, public static Map<String, Accumulator<?, ?>> copy(final Map<String, Accumulator<?,
?>> accumulators) { ?>> accumulators) {
......
...@@ -25,6 +25,7 @@ import org.apache.flink.api.common.distributions.DataDistribution; ...@@ -25,6 +25,7 @@ import org.apache.flink.api.common.distributions.DataDistribution;
import org.apache.flink.api.common.functions.GroupCombineFunction; import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.common.functions.Function; import org.apache.flink.api.common.functions.Function;
import org.apache.flink.api.common.functions.Partitioner; import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.FunctionUtils; import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.typeutils.TypeComparator; import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory; import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
...@@ -70,6 +71,7 @@ import org.slf4j.LoggerFactory; ...@@ -70,6 +71,7 @@ import org.slf4j.LoggerFactory;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
...@@ -508,14 +510,13 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i ...@@ -508,14 +510,13 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i
// JobManager. close() has been called earlier for all involved UDFs // JobManager. close() has been called earlier for all involved UDFs
// (using this.stub.close() and closeChainedTasks()), so UDFs can no longer // (using this.stub.close() and closeChainedTasks()), so UDFs can no longer
// modify accumulators; // modify accumulators;
if (this.stub != null) {
// collect the counters from the stub // collect the counters from the udf in the core driver
if (FunctionUtils.getFunctionRuntimeContext(this.stub, this.runtimeUdfContext) != null) {
Map<String, Accumulator<?, ?>> accumulators = Map<String, Accumulator<?, ?>> accumulators =
FunctionUtils.getFunctionRuntimeContext(this.stub, this.runtimeUdfContext).getAllAccumulators(); FunctionUtils.getFunctionRuntimeContext(this.stub, this.runtimeUdfContext).getAllAccumulators();
RegularPactTask.reportAndClearAccumulators(getEnvironment(), accumulators, this.chainedTasks);
} // collect accumulators from chained tasks and report them
} reportAndClearAccumulators(getEnvironment(), accumulators, this.chainedTasks);
} }
catch (Exception ex) { catch (Exception ex) {
// close the input, but do not report any exceptions, since we already have another root cause // close the input, but do not report any exceptions, since we already have another root cause
...@@ -572,16 +573,25 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i ...@@ -572,16 +573,25 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i
// We can merge here the accumulators from the stub and the chained // We can merge here the accumulators from the stub and the chained
// tasks. Type conflicts can occur here if counters with same name but // tasks. Type conflicts can occur here if counters with same name but
// different type were used. // different type were used.
if (!chainedTasks.isEmpty()) {
if (accumulators == null) {
accumulators = new HashMap<String, Accumulator<?, ?>>();
}
for (ChainedDriver<?, ?> chainedTask : chainedTasks) { for (ChainedDriver<?, ?> chainedTask : chainedTasks) {
if (FunctionUtils.getFunctionRuntimeContext(chainedTask.getStub(), null) != null) { RuntimeContext rc = FunctionUtils.getFunctionRuntimeContext(chainedTask.getStub(), null);
Map<String, Accumulator<?, ?>> chainedAccumulators = if (rc != null) {
FunctionUtils.getFunctionRuntimeContext(chainedTask.getStub(), null).getAllAccumulators(); Map<String, Accumulator<?, ?>> chainedAccumulators = rc.getAllAccumulators();
if (chainedAccumulators != null) {
AccumulatorHelper.mergeInto(accumulators, chainedAccumulators); AccumulatorHelper.mergeInto(accumulators, chainedAccumulators);
} }
} }
}
}
// Don't report if the UDF didn't collect any accumulators // Don't report if the UDF didn't collect any accumulators
if (accumulators.size() == 0) { if (accumulators == null || accumulators.size() == 0) {
return; return;
} }
...@@ -592,9 +602,11 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i ...@@ -592,9 +602,11 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i
// (e.g. in iterations) and we don't want to count twice. This may not be // (e.g. in iterations) and we don't want to count twice. This may not be
// done before sending // done before sending
AccumulatorHelper.resetAndClearAccumulators(accumulators); AccumulatorHelper.resetAndClearAccumulators(accumulators);
for (ChainedDriver<?, ?> chainedTask : chainedTasks) { for (ChainedDriver<?, ?> chainedTask : chainedTasks) {
if (FunctionUtils.getFunctionRuntimeContext(chainedTask.getStub(), null) != null) { RuntimeContext rc = FunctionUtils.getFunctionRuntimeContext(chainedTask.getStub(), null);
AccumulatorHelper.resetAndClearAccumulators(FunctionUtils.getFunctionRuntimeContext(chainedTask.getStub(), null).getAllAccumulators()); if (rc != null) {
AccumulatorHelper.resetAndClearAccumulators(rc.getAllAccumulators());
} }
} }
} }
...@@ -1140,7 +1152,7 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i ...@@ -1140,7 +1152,7 @@ public class RegularPactTask<S extends Function, OT> extends AbstractInvokable i
} catch (InterruptedException iex) { } catch (InterruptedException iex) {
throw new RuntimeException("Interrupted while waiting for input " + index + " to become available."); throw new RuntimeException("Interrupted while waiting for input " + index + " to become available.");
} catch (IOException ioex) { } catch (IOException ioex) {
throw new RuntimeException("An I/O Exception occurred whily obaining input " + index + "."); throw new RuntimeException("An I/O Exception occurred while obtaining input " + index + ".");
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册