提交 486cc908 编写于 作者: A Aljoscha Krettek 提交者: Stephan Ewen

[FLINK-1325] [Java API] Add Java ClosureCleaner

This closes #269
上级 026311ae
......@@ -52,6 +52,12 @@ under the License.
<artifactId>kryo</artifactId>
<version>2.24.0</version>
</dependency>
<dependency>
<groupId>org.ow2.asm</groupId>
<artifactId>asm</artifactId>
<version>4.0</version>
</dependency>
<!-- guava needs to be in "provided" scope, to make sure it is not included into the jars by the shading -->
<dependency>
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.util.InstantiationUtil;

import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
/**
 * The closure cleaner nulls out the {@code this$x} reference of anonymous inner
 * classes (closures) when bytecode analysis proves the enclosing instance is
 * never accessed, making otherwise non-serializable closures serializable.
 */
public class ClosureCleaner {

	private static final Logger LOG = LoggerFactory.getLogger(ClosureCleaner.class);

	/**
	 * Creates an ASM {@link ClassReader} for the given class by loading its
	 * class file from the classpath.
	 *
	 * @param cls The class to read.
	 * @return A ClassReader over the class's bytecode.
	 */
	private static ClassReader getClassReader(Class<?> cls) {
		// getResourceAsStream resolves relative to the class's own package,
		// so only the simple name (including Outer$Inner nesting) is needed.
		String className = cls.getName().replaceFirst("^.*\\.", "") + ".class";
		try (InputStream classStream = cls.getResourceAsStream(className)) {
			return new ClassReader(classStream);
		} catch (IOException e) {
			// chain the cause instead of flattening it into the message
			throw new RuntimeException("Could not create ClassReader: " + e, e);
		}
	}

	/**
	 * Cleans the closure of the given object: if it is a non-static inner
	 * class whose enclosing-instance reference is never used, that reference
	 * is set to null.
	 *
	 * @param func The closure object to clean.
	 * @param checkSerializable If true, verify afterwards that the object serializes.
	 */
	public static void clean(Object func, boolean checkSerializable) {
		Class<?> cls = func.getClass();

		// Find the field holding the enclosing instance. Depending on the
		// nesting level it is named "this$0", "this$1", ...
		for (Field f : cls.getDeclaredFields()) {
			if (f.getName().startsWith("this$")) {
				cleanThis0(func, cls, f.getName());
			}
		}

		if (checkSerializable) {
			ensureSerializable(func);
		}
	}

	/**
	 * Nulls the named enclosing-instance field of {@code func} if bytecode
	 * analysis shows no method of {@code cls} ever reads it.
	 */
	private static void cleanThis0(Object func, Class<?> cls, String this0Name) {
		// Scan all method bodies for a GETFIELD on the this$x field.
		This0AccessFinder this0Finder = new This0AccessFinder(this0Name);
		getClassReader(cls).accept(this0Finder, 0);

		if (LOG.isDebugEnabled()) {
			LOG.debug("{} is accessed: {}", this0Name, this0Finder.isThis0Accessed());
		}

		if (!this0Finder.isThis0Accessed()) {
			Field this0;
			try {
				this0 = func.getClass().getDeclaredField(this0Name);
			} catch (NoSuchFieldException e) {
				// cannot normally happen: the caller found this very field via reflection
				throw new RuntimeException("Could not find field " + this0Name + ": " + e, e);
			}

			this0.setAccessible(true);
			try {
				// Drop the unused reference so the (possibly non-serializable)
				// enclosing instance is not pulled into serialization.
				this0.set(func, null);
			} catch (IllegalAccessException e) {
				// should not happen, since we use setAccessible
				throw new RuntimeException("Could not set " + this0Name + ": " + e, e);
			}
		}
	}

	/**
	 * Verifies that the given object is serializable.
	 *
	 * @param func The object to serialize.
	 * @throws InvalidProgramException If serialization fails.
	 */
	public static void ensureSerializable(Object func) {
		try {
			InstantiationUtil.serializeObject(func);
		} catch (Exception e) {
			throw new InvalidProgramException("Task " + func + " not serializable: ", e);
		}
	}
}
/**
 * ASM class visitor that inspects every method body of a class and records
 * whether any of them reads (GETFIELD) the given enclosing-instance field
 * (e.g. {@code this$0}).
 */
class This0AccessFinder extends ClassVisitor {

	// the this$x field we are scanning for
	private final String searchedField;

	// flips to true as soon as any method is seen reading the field
	private boolean accessed;

	public This0AccessFinder(String this0Name) {
		super(Opcodes.ASM4);
		this.searchedField = this0Name;
	}

	/** Returns whether any visited method accessed the searched field. */
	public boolean isThis0Accessed() {
		return accessed;
	}

	@Override
	public MethodVisitor visitMethod(int access, String name, String desc, String sig, String[] exceptions) {
		// Inspect each method's instructions for a read of the searched field.
		return new MethodVisitor(Opcodes.ASM4) {
			@Override
			public void visitFieldInsn(int op, String owner, String fieldName, String fieldDesc) {
				if (op == Opcodes.GETFIELD && fieldName.equals(searchedField)) {
					accessed = true;
				}
			}
		};
	}
}
\ No newline at end of file
......@@ -92,8 +92,7 @@ public abstract class DataSet<T> {
private final ExecutionEnvironment context;
private final TypeInformation<T> type;
protected DataSet(ExecutionEnvironment context, TypeInformation<T> type) {
if (context == null) {
throw new NullPointerException("context is null");
......@@ -128,7 +127,15 @@ public abstract class DataSet<T> {
public TypeInformation<T> getType() {
return this.type;
}
public <F> F clean(F f) {
if (getExecutionEnvironment().getConfig().isClosureCleanerEnabled()) {
ClosureCleaner.clean(f, true);
}
ClosureCleaner.ensureSerializable(f);
return f;
}
// --------------------------------------------------------------------------------------------
// Filter & Transformations
// --------------------------------------------------------------------------------------------
......@@ -152,7 +159,7 @@ public abstract class DataSet<T> {
TypeInformation<R> resultType = TypeExtractor.getMapReturnTypes(mapper, this.getType());
return new MapOperator<T, R>(this, resultType, mapper, Utils.getCallLocationName());
return new MapOperator<T, R>(this, resultType, clean(mapper), Utils.getCallLocationName());
}
......@@ -180,7 +187,7 @@ public abstract class DataSet<T> {
throw new NullPointerException("MapPartition function must not be null.");
}
TypeInformation<R> resultType = TypeExtractor.getMapPartitionReturnTypes(mapPartition, this.getType());
return new MapPartitionOperator<T, R>(this, resultType, mapPartition, Utils.getCallLocationName());
return new MapPartitionOperator<T, R>(this, resultType, clean(mapPartition), Utils.getCallLocationName());
}
/**
......@@ -201,7 +208,7 @@ public abstract class DataSet<T> {
}
TypeInformation<R> resultType = TypeExtractor.getFlatMapReturnTypes(flatMapper, this.getType());
return new FlatMapOperator<T, R>(this, resultType, flatMapper, Utils.getCallLocationName());
return new FlatMapOperator<T, R>(this, resultType, clean(flatMapper), Utils.getCallLocationName());
}
/**
......@@ -221,7 +228,7 @@ public abstract class DataSet<T> {
if (filter == null) {
throw new NullPointerException("Filter function must not be null.");
}
return new FilterOperator<T>(this, filter, Utils.getCallLocationName());
return new FilterOperator<T>(this, clean(filter), Utils.getCallLocationName());
}
......@@ -335,7 +342,7 @@ public abstract class DataSet<T> {
if (reducer == null) {
throw new NullPointerException("Reduce function must not be null.");
}
return new ReduceOperator<T>(this, reducer, Utils.getCallLocationName());
return new ReduceOperator<T>(this, clean(reducer), Utils.getCallLocationName());
}
/**
......@@ -356,7 +363,7 @@ public abstract class DataSet<T> {
throw new NullPointerException("GroupReduce function must not be null.");
}
TypeInformation<R> resultType = TypeExtractor.getGroupReduceReturnTypes(reducer, this.getType());
return new GroupReduceOperator<T, R>(this, resultType, reducer, Utils.getCallLocationName());
return new GroupReduceOperator<T, R>(this, resultType, clean(reducer), Utils.getCallLocationName());
}
/**
......@@ -532,7 +539,7 @@ public abstract class DataSet<T> {
*/
public <K> UnsortedGrouping<T> groupBy(KeySelector<T, K> keyExtractor) {
TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keyExtractor, type);
return new UnsortedGrouping<T>(this, new Keys.SelectorFunctionKeys<T, K>(keyExtractor, getType(), keyType));
return new UnsortedGrouping<T>(this, new Keys.SelectorFunctionKeys<T, K>(clean(keyExtractor), getType(), keyType));
}
/**
......@@ -970,7 +977,7 @@ public abstract class DataSet<T> {
*/
public <K extends Comparable<K>> PartitionOperator<T> partitionByHash(KeySelector<T, K> keyExtractor) {
final TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keyExtractor, type);
return new PartitionOperator<T>(this, PartitionMethod.HASH, new Keys.SelectorFunctionKeys<T, K>(keyExtractor, this.getType(), keyType), Utils.getCallLocationName());
return new PartitionOperator<T>(this, PartitionMethod.HASH, new Keys.SelectorFunctionKeys<T, K>(clean(keyExtractor), this.getType(), keyType), Utils.getCallLocationName());
}
/**
......@@ -984,7 +991,7 @@ public abstract class DataSet<T> {
* @return The partitioned DataSet.
*/
public <K> PartitionOperator<T> partitionCustom(Partitioner<K> partitioner, int field) {
return new PartitionOperator<T>(this, new Keys.ExpressionKeys<T>(new int[] {field}, getType(), false), partitioner, Utils.getCallLocationName());
return new PartitionOperator<T>(this, new Keys.ExpressionKeys<T>(new int[] {field}, getType(), false), clean(partitioner), Utils.getCallLocationName());
}
/**
......@@ -998,7 +1005,7 @@ public abstract class DataSet<T> {
* @return The partitioned DataSet.
*/
public <K> PartitionOperator<T> partitionCustom(Partitioner<K> partitioner, String field) {
return new PartitionOperator<T>(this, new Keys.ExpressionKeys<T>(new String[] {field}, getType()), partitioner, Utils.getCallLocationName());
return new PartitionOperator<T>(this, new Keys.ExpressionKeys<T>(new String[] {field}, getType()), clean(partitioner), Utils.getCallLocationName());
}
/**
......@@ -1017,7 +1024,7 @@ public abstract class DataSet<T> {
*/
public <K extends Comparable<K>> PartitionOperator<T> partitionCustom(Partitioner<K> partitioner, KeySelector<T, K> keyExtractor) {
final TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keyExtractor, type);
return new PartitionOperator<T>(this, new Keys.SelectorFunctionKeys<T, K>(keyExtractor, this.getType(), keyType), partitioner, Utils.getCallLocationName());
return new PartitionOperator<T>(this, new Keys.SelectorFunctionKeys<T, K>(keyExtractor, this.getType(), keyType), clean(partitioner), Utils.getCallLocationName());
}
/**
......@@ -1095,7 +1102,7 @@ public abstract class DataSet<T> {
* @see TextOutputFormat
*/
public DataSink<String> writeAsFormattedText(String filePath, WriteMode writeMode, final TextFormatter<T> formatter) {
return this.map(new FormattingMapper<T>(formatter)).writeAsText(filePath, writeMode);
return this.map(new FormattingMapper<T>(clean(formatter))).writeAsText(filePath, writeMode);
}
/**
......
......@@ -149,7 +149,7 @@ public class CoGroupOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1, I2, OU
keys1.validateCustomPartitioner(partitioner, null);
keys2.validateCustomPartitioner(partitioner, null);
}
this.customPartitioner = partitioner;
this.customPartitioner = getInput1().clean(partitioner);
return this;
}
......@@ -590,7 +590,7 @@ public class CoGroupOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1, I2, OU
keys1.validateCustomPartitioner(partitioner, null);
keys2.validateCustomPartitioner(partitioner, null);
}
this.customPartitioner = partitioner;
this.customPartitioner = input1.clean(partitioner);
return this;
}
......@@ -619,7 +619,7 @@ public class CoGroupOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1, I2, OU
}
TypeInformation<R> returnType = TypeExtractor.getCoGroupReturnTypes(function, input1.getType(), input2.getType());
return new CoGroupOperator<I1, I2, R>(input1, input2, keys1, keys2, function, returnType,
return new CoGroupOperator<I1, I2, R>(input1, input2, keys1, keys2, input1.clean(function), returnType,
groupSortKeyOrderFirst, groupSortKeyOrderSecond,
customPartitioner, Utils.getCallLocationName());
}
......
......@@ -138,7 +138,7 @@ public class CrossOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1, I2, OUT,
throw new NullPointerException("Cross function must not be null.");
}
TypeInformation<R> returnType = TypeExtractor.getCrossReturnTypes(function, input1.getType(), input2.getType());
return new CrossOperator<I1, I2, R>(input1, input2, function, returnType, Utils.getCallLocationName());
return new CrossOperator<I1, I2, R>(input1, input2, clean(function), returnType, Utils.getCallLocationName());
}
/**
......
......@@ -149,7 +149,7 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
keys1.validateCustomPartitioner(partitioner, null);
keys2.validateCustomPartitioner(partitioner, null);
}
this.customPartitioner = partitioner;
this.customPartitioner = getInput1().clean(partitioner);
return this;
}
......@@ -520,14 +520,14 @@ public abstract class JoinOperator<I1, I2, OUT> extends TwoInputUdfOperator<I1,
throw new NullPointerException("Join function must not be null.");
}
TypeInformation<R> returnType = TypeExtractor.getFlatJoinReturnTypes(function, getInput1Type(), getInput2Type());
return new EquiJoin<I1, I2, R>(getInput1(), getInput2(), getKeys1(), getKeys2(), function, returnType, getJoinHint(), Utils.getCallLocationName());
return new EquiJoin<I1, I2, R>(getInput1(), getInput2(), getKeys1(), getKeys2(), clean(function), returnType, getJoinHint(), Utils.getCallLocationName());
}
public <R> EquiJoin<I1, I2, R> with (JoinFunction<I1, I2, R> function) {
if (function == null) {
throw new NullPointerException("Join function must not be null.");
}
FlatJoinFunction<I1, I2, R> generatedFunction = new WrappingFlatJoinFunction<I1, I2, R>(function);
FlatJoinFunction<I1, I2, R> generatedFunction = new WrappingFlatJoinFunction<I1, I2, R>(clean(function));
TypeInformation<R> returnType = TypeExtractor.getJoinReturnTypes(function, getInput1Type(), getInput2Type());
return new EquiJoin<I1, I2, R>(getInput1(), getInput2(), getKeys1(), getKeys2(), generatedFunction, function, returnType, getJoinHint(), Utils.getCallLocationName());
}
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.functions;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.ClosureCleaner;
import org.junit.Assert;
import org.junit.Test;
import java.io.Serializable;
/**
 * Tests for {@code ClosureCleaner}: closures whose enclosing instance is
 * unused must become serializable after cleaning, while closures that truly
 * need their enclosing instance must still fail the serializability check.
 */
public class ClosureCleanerTest {

	@Test(expected = InvalidProgramException.class)
	public void testNonSerializable() throws Exception {
		MapCreator creator = new NonSerializableMapCreator();
		MapFunction<Integer, Integer> map = creator.getMap();

		// Without cleaning, the captured non-serializable creator makes this throw.
		ClosureCleaner.ensureSerializable(map);

		int result = map.map(3);
		// JUnit convention: expected value first, actual second.
		Assert.assertEquals(4, result);
	}

	@Test
	public void testCleanedNonSerializable() throws Exception {
		MapCreator creator = new NonSerializableMapCreator();
		MapFunction<Integer, Integer> map = creator.getMap();

		// Cleaning nulls the unused this$0, so the serializability check passes.
		ClosureCleaner.clean(map, true);

		int result = map.map(3);
		Assert.assertEquals(4, result);
	}

	@Test
	public void testSerializable() throws Exception {
		MapCreator creator = new SerializableMapCreator(1);
		MapFunction<Integer, Integer> map = creator.getMap();

		// this$0 is accessed here, but the enclosing instance is serializable.
		ClosureCleaner.clean(map, true);

		int result = map.map(3);
		Assert.assertEquals(4, result);
	}

	@Test
	public void testNestedSerializable() throws Exception {
		MapCreator creator = new NestedSerializableMapCreator(1);
		MapFunction<Integer, Integer> map = creator.getMap();

		// Entire this$0 chain (inner + outer creator) is serializable.
		ClosureCleaner.clean(map, true);
		ClosureCleaner.ensureSerializable(map);

		int result = map.map(3);
		Assert.assertEquals(4, result);
	}

	@Test(expected = InvalidProgramException.class)
	public void testNestedNonSerializable() throws Exception {
		MapCreator creator = new NestedNonSerializableMapCreator(1);
		MapFunction<Integer, Integer> map = creator.getMap();

		// The closure needs its this$0 chain (it calls getMeTheAdd() on the
		// non-serializable outer class), so cleaning cannot help and the
		// serializability check must throw.
		ClosureCleaner.clean(map, true);
		ClosureCleaner.ensureSerializable(map);

		int result = map.map(3);
		Assert.assertEquals(4, result);
	}
}
/** Factory for the map-function closures exercised by the tests above. */
interface MapCreator {

	MapFunction<Integer, Integer> getMap();
}
/**
 * Not serializable itself. The anonymous MapFunction it returns carries a
 * this$0 reference to this creator but never reads it, so the ClosureCleaner
 * can null the reference and make the function serializable.
 *
 * NOTE: the anonymous-class form is intentional — do not convert to a lambda
 * or static class, as that would change the this$0 capture under test.
 */
class NonSerializableMapCreator implements MapCreator {

	@Override
	public MapFunction<Integer, Integer> getMap() {
		return new MapFunction<Integer, Integer>() {
			@Override
			public Integer map(Integer value) throws Exception {
				// does not touch the enclosing instance
				return value + 1;
			}
		};
	}
}
/**
 * Serializable creator whose anonymous MapFunction reads the outer field
 * {@code add} through this$0. The reference is therefore accessed and must
 * NOT be nulled by the cleaner; serialization still works because this
 * enclosing instance is itself serializable.
 */
class SerializableMapCreator implements MapCreator, Serializable {

	// value captured by the returned closure via this$0
	private int add = 0;

	public SerializableMapCreator(int add) {
		this.add = add;
	}

	@Override
	public MapFunction<Integer, Integer> getMap() {
		return new MapFunction<Integer, Integer>() {
			@Override
			public Integer map(Integer value) throws Exception {
				// reads the outer field -> this$0 is accessed
				return value + add;
			}
		};
	}
}
/**
 * Two-level nesting where every link in the this$0 chain is serializable:
 * the anonymous MapFunction captures the inner creator, which in turn
 * captures this outer creator to reach {@code add}. Cleaning must leave the
 * chain intact and serialization must still succeed.
 */
class NestedSerializableMapCreator implements MapCreator, Serializable {

	// value the innermost closure ultimately reads
	private int add = 0;
	private InnerSerializableMapCreator inner;

	public NestedSerializableMapCreator(int add) {
		this.add = add;
		this.inner = new InnerSerializableMapCreator();
	}

	@Override
	public MapFunction<Integer, Integer> getMap() {
		return inner.getMap();
	}

	// non-static inner class: instances hold a this$0 to the outer creator
	class InnerSerializableMapCreator implements MapCreator, Serializable {

		@Override
		public MapFunction<Integer, Integer> getMap() {
			return new MapFunction<Integer, Integer>() {
				@Override
				public Integer map(Integer value) throws Exception {
					// reads the outer-outer field through the this$0 chain
					return value + add;
				}
			};
		}
	}
}
/**
 * Two-level nesting with a non-serializable outer creator. The anonymous
 * MapFunction calls {@code getMeTheAdd()} on the inner creator, so its this$0
 * is accessed; the inner creator's own this$0 points at this non-serializable
 * outer class. The cleaner cannot null either reference, so the
 * serializability check must fail.
 */
class NestedNonSerializableMapCreator implements MapCreator {

	// value reached only through the this$0 chain via getMeTheAdd()
	private int add = 0;
	private InnerSerializableMapCreator inner;

	public NestedNonSerializableMapCreator(int add) {
		this.add = add;
		this.inner = new InnerSerializableMapCreator();
	}

	@Override
	public MapFunction<Integer, Integer> getMap() {
		return inner.getMap();
	}

	// serializable on paper, but its this$0 points at the non-serializable outer
	class InnerSerializableMapCreator implements MapCreator, Serializable {

		@Override
		public MapFunction<Integer, Integer> getMap() {
			return new MapFunction<Integer, Integer>() {
				@Override
				public Integer map(Integer value) throws Exception {
					// method call on the enclosing instance -> this$0 is accessed
					return value + getMeTheAdd();
				}
			};
		}

		public int getMeTheAdd() {
			// reads the outer class's field through its own this$0
			return add;
		}
	}
}
......@@ -172,7 +172,7 @@ object ClosureCleaner {
}
}
private def ensureSerializable(func: AnyRef) {
def ensureSerializable(func: AnyRef) {
try {
InstantiationUtil.serializeObject(func)
} catch {
......
......@@ -117,6 +117,7 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
if (set.getExecutionEnvironment.getConfig.isClosureCleanerEnabled) {
ClosureCleaner.clean(f, checkSerializable)
}
ClosureCleaner.ensureSerializable(f)
f
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册