提交 90b3b4cf 编写于 作者: T Till Rohrmann 提交者: zentol

[FLINK-7943] Make ParameterTool thread safe

This commit changes the serialization of the ParameterTool such that only the
data map is contained. The defaultData and the unrequestedParameters maps are
not serialized because they are only used on the client side. Additionally, the
defaultData and unrequestedParameters map are being made thread safe by using
ConcurrentHashMaps.

This closes #4921.
上级 ec863708
......@@ -24,6 +24,7 @@ import org.apache.flink.api.java.typeutils.GenericTypeInfo;
import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer;
import org.apache.flink.core.testutils.CommonTestUtils;
import org.apache.flink.util.SerializedValue;
import org.apache.flink.util.TestLogger;
import org.junit.Test;
......@@ -37,7 +38,7 @@ import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
public class ExecutionConfigTest {
public class ExecutionConfigTest extends TestLogger {
@Test
public void testDoubleTypeRegistration() {
......
......@@ -32,6 +32,7 @@ import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.Arrays;
......@@ -39,8 +40,10 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
/**
* This class provides simple utility methods for reading and parsing program arguments from different sources.
......@@ -212,13 +215,38 @@ public class ParameterTool extends ExecutionConfig.GlobalJobParameters implement
// ------------------ ParameterUtil ------------------------
protected final Map<String, String> data;
protected final Map<String, String> defaultData;
protected final Set<String> unrequestedParameters;
// data which is only used on the client and does not need to be transmitted
protected transient Map<String, String> defaultData;
protected transient Set<String> unrequestedParameters;
private ParameterTool(Map<String, String> data) {
this.data = new HashMap<>(data);
this.defaultData = new HashMap<>();
this.unrequestedParameters = new HashSet<>(data.keySet());
this.data = Collections.unmodifiableMap(new HashMap<>(data));
this.defaultData = new ConcurrentHashMap<>(data.size());
this.unrequestedParameters = Collections.newSetFromMap(new ConcurrentHashMap<>(data.size()));
unrequestedParameters.addAll(data.keySet());
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
ParameterTool that = (ParameterTool) o;
return Objects.equals(data, that.data) &&
Objects.equals(defaultData, that.defaultData) &&
Objects.equals(unrequestedParameters, that.unrequestedParameters);
}
@Override
public int hashCode() {
return Objects.hash(data, defaultData, unrequestedParameters);
}
/**
......@@ -560,9 +588,21 @@ public class ParameterTool extends ExecutionConfig.GlobalJobParameters implement
* @return The Merged {@link ParameterTool}
*/
public ParameterTool mergeWith(ParameterTool other) {
ParameterTool ret = new ParameterTool(this.data);
ret.data.putAll(other.data);
ret.unrequestedParameters.addAll(other.unrequestedParameters);
Map<String, String> resultData = new HashMap<>(data.size() + other.data.size());
resultData.putAll(data);
resultData.putAll(other.data);
ParameterTool ret = new ParameterTool(resultData);
final HashSet<String> requestedParametersLeft = new HashSet<>(data.keySet());
requestedParametersLeft.removeAll(unrequestedParameters);
final HashSet<String> requestedParametersRight = new HashSet<>(other.data.keySet());
requestedParametersRight.removeAll(other.unrequestedParameters);
ret.unrequestedParameters.removeAll(requestedParametersLeft);
ret.unrequestedParameters.removeAll(requestedParametersRight);
return ret;
}
......@@ -573,4 +613,12 @@ public class ParameterTool extends ExecutionConfig.GlobalJobParameters implement
return data;
}
// ------------------------- Serialization ---------------------------------------------
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
in.defaultReadObject();
defaultData = Collections.emptyMap();
unrequestedParameters = Collections.emptySet();
}
}
......@@ -83,24 +83,28 @@ public class RequiredParameters {
* <p>If any check fails, a RequiredParametersException is thrown
*
* @param parameterTool - parameters supplied by the user.
* @return the updated ParameterTool containing all the required parameters
* @throws RequiredParametersException if any of the specified checks fail
*/
public void applyTo(ParameterTool parameterTool) throws RequiredParametersException {
public ParameterTool applyTo(ParameterTool parameterTool) throws RequiredParametersException {
List<String> missingArguments = new LinkedList<>();
HashMap<String, String> newParameters = new HashMap<>(parameterTool.toMap());
for (Option o : data.values()) {
if (parameterTool.data.containsKey(o.getName())) {
if (Objects.equals(parameterTool.data.get(o.getName()), ParameterTool.NO_VALUE_KEY)) {
if (newParameters.containsKey(o.getName())) {
if (Objects.equals(newParameters.get(o.getName()), ParameterTool.NO_VALUE_KEY)) {
// the parameter has been passed, but no value, check if there is a default value
checkAndApplyDefaultValue(o, parameterTool.data);
checkAndApplyDefaultValue(o, newParameters);
} else {
// a value has been passed in the parameterTool, now check if it adheres to all constraints
checkAmbiguousValues(o, parameterTool.data);
checkIsCastableToDefinedType(o, parameterTool.data);
checkChoices(o, parameterTool.data);
checkAmbiguousValues(o, newParameters);
checkIsCastableToDefinedType(o, newParameters);
checkChoices(o, newParameters);
}
} else {
// check if there is a default name or a value passed for a possibly defined alternative name.
if (hasNoDefaultValueAndNoValuePassedOnAlternativeName(o, parameterTool.data)) {
if (hasNoDefaultValueAndNoValuePassedOnAlternativeName(o, newParameters)) {
missingArguments.add(o.getName());
}
}
......@@ -108,6 +112,8 @@ public class RequiredParameters {
if (!missingArguments.isEmpty()) {
throw new RequiredParametersException(this.missingArgumentsText(missingArguments), missingArguments);
}
return ParameterTool.fromMap(newParameters);
}
// check if the given parameter has a default value and add it to the passed map if that is the case
......
......@@ -23,16 +23,29 @@ import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
/**
* Tests for {@link ParameterTool}.
......@@ -574,6 +587,78 @@ public class ParameterToolTest extends AbstractParameterToolTest {
Assert.assertEquals(Collections.emptySet(), parameter.getUnrequestedParameters());
}
/**
* Tests that we can concurrently serialize and access the ParameterTool. See FLINK-7943
*/
@Test
public void testConcurrentExecutionConfigSerialization() throws ExecutionException, InterruptedException {
final int numInputs = 10;
Collection<String> input = new ArrayList<>(numInputs);
for (int i = 0; i < numInputs; i++) {
input.add("--" + UUID.randomUUID());
input.add(UUID.randomUUID().toString());
}
final String[] args = input.toArray(new String[0]);
final ParameterTool parameterTool = ParameterTool.fromArgs(args);
final int numThreads = 5;
final int numSerializations = 100;
final Collection<CompletableFuture<Void>> futures = new ArrayList<>(numSerializations);
final ExecutorService executorService = Executors.newFixedThreadPool(numThreads);
try {
for (int i = 0; i < numSerializations; i++) {
futures.add(
CompletableFuture.runAsync(
() -> {
try {
serializeDeserialize(parameterTool);
} catch (Exception e) {
throw new CompletionException(e);
}
},
executorService));
}
for (CompletableFuture<Void> future : futures) {
future.get();
}
} finally {
executorService.shutdownNow();
executorService.awaitTermination(1000L, TimeUnit.MILLISECONDS);
}
}
/**
* Accesses parameter tool parameters and then serializes the given parameter tool and deserializes again.
* @param parameterTool to serialize/deserialize
*/
private void serializeDeserialize(ParameterTool parameterTool) throws IOException, ClassNotFoundException {
// weirdly enough, this call has side effects making the ParameterTool serialization fail if not
// using a concurrent data structure.
parameterTool.get(UUID.randomUUID().toString());
try (
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos)) {
oos.writeObject(parameterTool);
oos.close();
baos.close();
ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
ObjectInputStream ois = new ObjectInputStream(bais);
// this should work :-)
ParameterTool deserializedParameterTool = ((ParameterTool) ois.readObject());
}
}
private static <T> Set<T> createHashSet(T... elements) {
Set<T> set = new HashSet<>();
for (T element : elements) {
......
......@@ -18,6 +18,8 @@
package org.apache.flink.api.java.utils;
import org.apache.flink.util.TestLogger;
import org.hamcrest.CoreMatchers;
import org.junit.Assert;
import org.junit.Rule;
......@@ -33,7 +35,7 @@ import static org.junit.Assert.fail;
/**
* Tests for RequiredParameter class and its interactions with ParameterTool.
*/
public class RequiredParametersTest {
public class RequiredParametersTest extends TestLogger {
@Rule
public ExpectedException expectedException = ExpectedException.none();
......@@ -122,7 +124,7 @@ public class RequiredParametersTest {
try {
required.add(new Option("berlin").alt("b"));
required.applyTo(parameter);
parameter = required.applyTo(parameter);
Assert.assertEquals(parameter.data.get("berlin"), "value");
Assert.assertEquals(parameter.data.get("b"), "value");
} catch (RequiredParametersException e) {
......@@ -137,7 +139,7 @@ public class RequiredParametersTest {
try {
required.add(new Option("berlin").alt("b").defaultValue("something"));
required.applyTo(parameter);
parameter = required.applyTo(parameter);
Assert.assertEquals(parameter.data.get("berlin"), "value");
Assert.assertEquals(parameter.data.get("b"), "value");
} catch (RequiredParametersException e) {
......@@ -164,7 +166,7 @@ public class RequiredParametersTest {
RequiredParameters required = new RequiredParameters();
try {
required.add(new Option("berlin"));
required.applyTo(parameter);
parameter = required.applyTo(parameter);
Assert.assertEquals(parameter.data.get("berlin"), "value");
} catch (RequiredParametersException e) {
fail("Exception thrown " + e.getMessage());
......@@ -177,7 +179,7 @@ public class RequiredParametersTest {
RequiredParameters required = new RequiredParameters();
try {
required.add(new Option("berlin").defaultValue("value"));
required.applyTo(parameter);
parameter = required.applyTo(parameter);
Assert.assertEquals(parameter.data.get("berlin"), "value");
} catch (RequiredParametersException e) {
fail("Exception thrown " + e.getMessage());
......@@ -190,7 +192,7 @@ public class RequiredParametersTest {
RequiredParameters required = new RequiredParameters();
try {
required.add(new Option("berlin").alt("b").defaultValue("value"));
required.applyTo(parameter);
parameter = required.applyTo(parameter);
Assert.assertEquals(parameter.data.get("berlin"), "value");
Assert.assertEquals(parameter.data.get("b"), "value");
} catch (RequiredParametersException e) {
......@@ -205,7 +207,7 @@ public class RequiredParametersTest {
try {
rq.add("input");
rq.add(new Option("parallelism").alt("p").defaultValue("1").type(OptionType.INTEGER));
rq.applyTo(parameter);
parameter = rq.applyTo(parameter);
Assert.assertEquals(parameter.data.get("parallelism"), "1");
Assert.assertEquals(parameter.data.get("p"), "1");
Assert.assertEquals(parameter.data.get("input"), "abc");
......@@ -223,7 +225,7 @@ public class RequiredParametersTest {
required.add(new Option("count").defaultValue("15"));
required.add(new Option("someFlag").alt("sf").defaultValue("true"));
required.applyTo(parameter);
parameter = required.applyTo(parameter);
Assert.assertEquals(parameter.data.get("berlin"), "value");
Assert.assertEquals(parameter.data.get("count"), "15");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册