提交 069653ba 编写于 作者: R robm

8172299: Improve class processing

Reviewed-by: rriggs
上级 bdbe6537
/* /*
* Copyright (c) 1996, 2016, Oracle and/or its affiliates. All rights reserved. * Copyright (c) 1996, 2017, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
* *
* This code is free software; you can redistribute it and/or modify it * This code is free software; you can redistribute it and/or modify it
...@@ -1774,12 +1774,19 @@ public class ObjectInputStream ...@@ -1774,12 +1774,19 @@ public class ObjectInputStream
} catch (ClassNotFoundException ex) { } catch (ClassNotFoundException ex) {
resolveEx = ex; resolveEx = ex;
} }
skipCustomData();
desc.initProxy(cl, resolveEx, readClassDesc(false)); // Call filterCheck on the class before reading anything else
filterCheck(cl, -1);
skipCustomData();
// Call filterCheck on the definition try {
filterCheck(desc.forClass(), -1); totalObjectRefs++;
depth++;
desc.initProxy(cl, resolveEx, readClassDesc(false));
} finally {
depth--;
}
handles.finish(descHandle); handles.finish(descHandle);
passHandle = descHandle; passHandle = descHandle;
...@@ -1824,12 +1831,19 @@ public class ObjectInputStream ...@@ -1824,12 +1831,19 @@ public class ObjectInputStream
} catch (ClassNotFoundException ex) { } catch (ClassNotFoundException ex) {
resolveEx = ex; resolveEx = ex;
} }
skipCustomData();
desc.initNonProxy(readDesc, cl, resolveEx, readClassDesc(false)); // Call filterCheck on the class before reading anything else
filterCheck(cl, -1);
skipCustomData();
// Call filterCheck on the definition try {
filterCheck(desc.forClass(), -1); totalObjectRefs++;
depth++;
desc.initNonProxy(readDesc, cl, resolveEx, readClassDesc(false));
} finally {
depth--;
}
handles.finish(descHandle); handles.finish(descHandle);
passHandle = descHandle; passHandle = descHandle;
......
/* /*
* Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved. * Copyright (c) 2016, 2017, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
* *
* This code is free software; you can redistribute it and/or modify it * This code is free software; you can redistribute it and/or modify it
...@@ -33,10 +33,11 @@ import java.lang.invoke.SerializedLambda; ...@@ -33,10 +33,11 @@ import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Constructor; import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException; import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy; import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashSet; import java.util.HashSet;
import java.util.Hashtable; import java.util.Hashtable;
import java.util.Set; import java.util.List;
import java.util.concurrent.atomic.LongAdder; import java.util.concurrent.atomic.LongAdder;
import sun.misc.ObjectInputFilter; import sun.misc.ObjectInputFilter;
...@@ -155,26 +156,33 @@ public class SerialFilterTest implements Serializable { ...@@ -155,26 +156,33 @@ public class SerialFilterTest implements Serializable {
Runnable runnable = (Runnable & Serializable) SerialFilterTest::noop; Runnable runnable = (Runnable & Serializable) SerialFilterTest::noop;
Object[][] objects = { Object[][] objects = {
{ null, 0, -1, 0, 0, 0, { null, 0, -1, 0, 0, 0,
new HashSet<>()}, // no callback, no values Arrays.asList()}, // no callback, no values
{ objArray, 3, 7, 8, 2, 55, { objArray, 3, 7, 9, 2, 55,
new HashSet<>(Arrays.asList(objArray.getClass()))}, Arrays.asList(objArray.getClass(), objArray.getClass())},
{ Object[].class, 1, -1, 1, 1, 40, { Object[].class, 1, -1, 1, 1, 38,
new HashSet<>(Arrays.asList(Object[].class))}, Arrays.asList(Object[].class)},
{ new SerialFilterTest(), 1, -1, 1, 1, 37, { new SerialFilterTest(), 1, -1, 1, 1, 35,
new HashSet<>(Arrays.asList(SerialFilterTest.class))}, Arrays.asList(SerialFilterTest.class)},
{ new LongAdder(), 2, -1, 1, 1, 93, { new LongAdder(), 2, -1, 2, 1, 93,
new HashSet<>(Arrays.asList(LongAdder.class, serClass))}, Arrays.asList(serClass, LongAdder.class)},
{ new byte[14], 2, 14, 1, 1, 27, { new byte[14], 2, 14, 2, 1, 27,
new HashSet<>(Arrays.asList(byteArray.getClass()))}, Arrays.asList(byteArray.getClass(), byteArray.getClass())},
{ runnable, 13, 0, 10, 2, 514, { runnable, 13, 0, 13, 2, 514,
new HashSet<>(Arrays.asList(java.lang.invoke.SerializedLambda.class, Arrays.asList(java.lang.invoke.SerializedLambda.class,
objArray.getClass(),
objArray.getClass(),
SerialFilterTest.class, SerialFilterTest.class,
objArray.getClass()))}, java.lang.invoke.SerializedLambda.class)},
{ deepHashSet(10), 48, -1, 49, 11, 619, { deepHashSet(10), 48, -1, 50, 11, 619,
new HashSet<>(Arrays.asList(HashSet.class))}, Arrays.asList(HashSet.class)},
{ proxy.getClass(), 3, -1, 1, 1, 114, { proxy.getClass(), 3, -1, 2, 2, 112,
new HashSet<>(Arrays.asList(Runnable.class, Arrays.asList(Runnable.class,
java.lang.reflect.Proxy.class))}, java.lang.reflect.Proxy.class,
java.lang.reflect.Proxy.class)},
{ new F(), 6, -1, 6, 6, 202,
Arrays.asList(F.class, E.class, D.class,
C.class, B.class, A.class)},
}; };
return objects; return objects;
} }
...@@ -213,11 +221,12 @@ public class SerialFilterTest implements Serializable { ...@@ -213,11 +221,12 @@ public class SerialFilterTest implements Serializable {
@Test(dataProvider="Objects") @Test(dataProvider="Objects")
public static void t1(Object object, public static void t1(Object object,
long count, long maxArray, long maxRefs, long maxDepth, long maxBytes, long count, long maxArray, long maxRefs, long maxDepth, long maxBytes,
Set<Class<?>> classes) throws IOException { List<Class<?>> classes) throws IOException {
byte[] bytes = writeObjects(object); byte[] bytes = writeObjects(object);
Validator validator = new Validator(); Validator validator = new Validator();
validate(bytes, validator); validate(bytes, validator);
System.out.printf("v: %s%n", validator); System.out.printf("v: %s%n", validator);
Assert.assertEquals(validator.count, count, "callback count wrong"); Assert.assertEquals(validator.count, count, "callback count wrong");
Assert.assertEquals(validator.classes, classes, "classes mismatch"); Assert.assertEquals(validator.classes, classes, "classes mismatch");
Assert.assertEquals(validator.maxArray, maxArray, "maxArray mismatch"); Assert.assertEquals(validator.maxArray, maxArray, "maxArray mismatch");
...@@ -411,7 +420,7 @@ public class SerialFilterTest implements Serializable { ...@@ -411,7 +420,7 @@ public class SerialFilterTest implements Serializable {
*/ */
static class Validator implements ObjectInputFilter { static class Validator implements ObjectInputFilter {
long count; // Count of calls to checkInput long count; // Count of calls to checkInput
HashSet<Class<?>> classes = new HashSet<>(); List<Class<?>> classes = new ArrayList<>();
long maxArray = -1; long maxArray = -1;
long maxRefs; long maxRefs;
long maxDepth; long maxDepth;
...@@ -422,16 +431,20 @@ public class SerialFilterTest implements Serializable { ...@@ -422,16 +431,20 @@ public class SerialFilterTest implements Serializable {
@Override @Override
public ObjectInputFilter.Status checkInput(FilterInfo filter) { public ObjectInputFilter.Status checkInput(FilterInfo filter) {
Class<?> serialClass = filter.serialClass();
System.out.printf(" checkInput: class: %s, arrayLen: %d, refs: %d, depth: %d, bytes; %d%n",
serialClass, filter.arrayLength(), filter.references(),
filter.depth(), filter.streamBytes());
count++; count++;
if (filter.serialClass() != null) { if (serialClass != null) {
if (filter.serialClass().getName().contains("$$Lambda$")) { if (serialClass.getName().contains("$$Lambda$")) {
// TBD: proper identification of serialized Lambdas? // TBD: proper identification of serialized Lambdas?
// Fold the serialized Lambda into the SerializedLambda type // Fold the serialized Lambda into the SerializedLambda type
classes.add(SerializedLambda.class); classes.add(SerializedLambda.class);
} else if (Proxy.isProxyClass(filter.serialClass())) { } else if (Proxy.isProxyClass(serialClass)) {
classes.add(Proxy.class); classes.add(Proxy.class);
} else { } else {
classes.add(filter.serialClass()); classes.add(serialClass);
} }
} }
...@@ -591,7 +604,8 @@ public class SerialFilterTest implements Serializable { ...@@ -591,7 +604,8 @@ public class SerialFilterTest implements Serializable {
// a stream of exactly the size requested. // a stream of exactly the size requested.
return genMaxBytesObject(allowed, value); return genMaxBytesObject(allowed, value);
} else if (pattern.startsWith("maxrefs=")) { } else if (pattern.startsWith("maxrefs=")) {
Object[] array = new Object[allowed ? (int)value - 1 : (int)value]; // 4 references to classes in addition to the array contents
Object[] array = new Object[allowed ? (int)value - 4 : (int)value - 3];
for (int i = 0; i < array.length; i++) { for (int i = 0; i < array.length; i++) {
array[i] = otherObject; array[i] = otherObject;
} }
...@@ -740,4 +754,25 @@ public class SerialFilterTest implements Serializable { ...@@ -740,4 +754,25 @@ public class SerialFilterTest implements Serializable {
return streamBytes; return streamBytes;
} }
} }
// Deeper superclass hierarchy
static class A implements Serializable {
private static final long serialVersionUID = 1L;
};
static class B extends A {
private static final long serialVersionUID = 2L;
}
static class C extends B {
private static final long serialVersionUID = 3L;
}
static class D extends C {
private static final long serialVersionUID = 4L;
}
static class E extends D {
private static final long serialVersionUID = 5L;
}
static class F extends E {
private static final long serialVersionUID = 6L;
}
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册