Commit 2ce080da authored by Geoffrey Mon, committed by zentol

[FLINK-5183] [py] Support multiple jobs per plan file

This closes #3232.
Parent 3e767b5a
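
With this change, one Python plan file can submit several jobs back to back: every get_environment() call now returns an Environment with a unique id, the Java side keeps reading plans until the Python process exits, and only the environment whose id the binder sends back executes its operators. Below is a minimal sketch of such a multi-job plan file, modeled on the test added at the bottom of this commit; all API calls shown (get_environment, set_parallelism, from_elements, map, output, execute) appear elsewhere in this diff.

    from flink.plan.Environment import get_environment

    if __name__ == "__main__":
        # First job: its own environment and plan
        env1 = get_environment()
        env1.set_parallelism(1)
        env1.from_elements(1, 2, 3) \
            .map(lambda x: x + 1) \
            .output()
        env1.execute(local=True)

        # Second job: a fresh environment with the next unique id
        env2 = get_environment()
        env2.set_parallelism(1)
        env2.from_elements(4, 5, 6) \
            .map(lambda x: x * 2) \
            .output()
        env2.execute(local=True)
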
@@ -44,8 +44,9 @@ public class PythonOperationInfo {
public String name;
public boolean usesUDF;
public int parallelism;
public int envID;
public PythonOperationInfo(PythonPlanStreamer streamer) throws IOException {
public PythonOperationInfo(PythonPlanStreamer streamer, int environmentID) throws IOException {
identifier = (String) streamer.getRecord();
parentID = (Integer) streamer.getRecord(true);
otherID = (Integer) streamer.getRecord(true);
@@ -92,6 +93,8 @@ public class PythonOperationInfo {
values[x] = streamer.getRecord();
}
parallelism = (Integer) streamer.getRecord(true);
envID = environmentID;
}
@Override
......
@@ -94,6 +94,7 @@ public class PythonPlanBinder {
private HashMap<Integer, Object> sets = new HashMap<>();
public ExecutionEnvironment env;
private int currentEnvironmentID = 0;
private PythonPlanStreamer streamer;
public static final int MAPPED_FILE_SIZE = 1024 * 1024 * 64;
@@ -126,8 +127,6 @@ public class PythonPlanBinder {
}
private void runPlan(String[] args) throws Exception {
env = ExecutionEnvironment.getExecutionEnvironment();
int split = 0;
for (int x = 0; x < args.length; x++) {
if (args[x].compareTo("-") == 0) {
@@ -139,15 +138,23 @@ public class PythonPlanBinder {
String tmpPath = FLINK_PYTHON_FILE_PATH + r.nextInt();
prepareFiles(tmpPath, Arrays.copyOfRange(args, 0, split == 0 ? args.length : split));
startPython(tmpPath, Arrays.copyOfRange(args, split == 0 ? args.length : split + 1, args.length));
receivePlan();
if (env instanceof LocalEnvironment) {
FLINK_HDFS_PATH = "file:" + System.getProperty("java.io.tmpdir") + File.separator + "flink";
// Python process should terminate itself when all jobs have been run
while (streamer.preparePlanMode()) {
receivePlan();
if (env instanceof LocalEnvironment) {
FLINK_HDFS_PATH = "file:" + System.getProperty("java.io.tmpdir") + File.separator + "flink";
}
distributeFiles(tmpPath, env);
JobExecutionResult jer = env.execute();
sendResult(jer);
streamer.finishPlanMode();
}
distributeFiles(tmpPath, env);
JobExecutionResult jer = env.execute();
sendResult(jer);
clearPath(tmpPath);
close();
} catch (Exception e) {
close();
@@ -200,7 +207,6 @@ public class PythonPlanBinder {
clearPath(FLINK_HDFS_PATH);
FileCache.copy(new Path(tmpPath), new Path(FLINK_HDFS_PATH), true);
env.registerCachedFile(FLINK_HDFS_PATH, FLINK_PYTHON_DC_ID);
clearPath(tmpPath);
}
private void startPython(String tempPath, String[] args) throws IOException {
@@ -234,6 +240,9 @@ public class PythonPlanBinder {
//====Plan==========================================================================================================
private void receivePlan() throws IOException {
env = ExecutionEnvironment.getExecutionEnvironment();
// IDs used in the HashMap of sets are only unique within each environment
sets.clear();
receiveParameters();
receiveOperations();
}
@@ -245,11 +254,12 @@ public class PythonPlanBinder {
private enum Parameters {
DOP,
MODE,
RETRY
RETRY,
ID
}
private void receiveParameters() throws IOException {
for (int x = 0; x < 3; x++) {
for (int x = 0; x < 4; x++) {
Tuple value = (Tuple) streamer.getRecord(true);
switch (Parameters.valueOf(((String) value.getField(0)).toUpperCase())) {
case DOP:
@@ -263,6 +273,9 @@ public class PythonPlanBinder {
int retry = (Integer) value.getField(1);
env.setRestartStrategy(RestartStrategies.fixedDelayRestart(retry, 10000L));
break;
case ID:
currentEnvironmentID = (Integer) value.getField(1);
break;
}
}
if (env.getParallelism() < 0) {
@@ -285,7 +298,7 @@ public class PythonPlanBinder {
private void receiveOperations() throws IOException {
Integer operationCount = (Integer) streamer.getRecord(true);
for (int x = 0; x < operationCount; x++) {
PythonOperationInfo info = new PythonOperationInfo(streamer);
PythonOperationInfo info = new PythonOperationInfo(streamer, currentEnvironmentID);
Operation op;
try {
op = Operation.valueOf(info.identifier.toUpperCase());
@@ -518,7 +531,7 @@ public class PythonPlanBinder {
DataSet op2 = (DataSet) sets.get(info.otherID);
Keys.ExpressionKeys<?> key1 = new Keys.ExpressionKeys(info.keys1, op1.getType());
Keys.ExpressionKeys<?> key2 = new Keys.ExpressionKeys(info.keys2, op2.getType());
PythonCoGroup pcg = new PythonCoGroup(info.setID, info.types);
PythonCoGroup pcg = new PythonCoGroup(info.envID, info.setID, info.types);
sets.put(info.setID, new CoGroupRawOperator(op1, op2, key1, key2, pcg, info.types, info.name).setParallelism(getParallelism(info)));
}
@@ -544,7 +557,7 @@ public class PythonPlanBinder {
defaultResult.setParallelism(getParallelism(info));
if (info.usesUDF) {
sets.put(info.setID, defaultResult.mapPartition(new PythonMapPartition(info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
sets.put(info.setID, defaultResult.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
} else {
sets.put(info.setID, defaultResult.name("DefaultCross"));
}
@@ -553,13 +566,13 @@ public class PythonPlanBinder {
@SuppressWarnings("unchecked")
private void createFilterOperation(PythonOperationInfo info) {
DataSet op1 = (DataSet) sets.get(info.parentID);
sets.put(info.setID, op1.mapPartition(new PythonMapPartition(info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
sets.put(info.setID, op1.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
}
@SuppressWarnings("unchecked")
private void createFlatMapOperation(PythonOperationInfo info) {
DataSet op1 = (DataSet) sets.get(info.parentID);
sets.put(info.setID, op1.mapPartition(new PythonMapPartition(info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
sets.put(info.setID, op1.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
}
private void createGroupReduceOperation(PythonOperationInfo info) {
@@ -580,19 +593,19 @@ public class PythonPlanBinder {
@SuppressWarnings("unchecked")
private DataSet applyGroupReduceOperation(DataSet op1, PythonOperationInfo info) {
return op1.reduceGroup(new IdentityGroupReduce()).setCombinable(false).name("PythonGroupReducePreStep").setParallelism(getParallelism(info))
.mapPartition(new PythonMapPartition(info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name);
.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name);
}
@SuppressWarnings("unchecked")
private DataSet applyGroupReduceOperation(UnsortedGrouping op1, PythonOperationInfo info) {
return op1.reduceGroup(new IdentityGroupReduce()).setCombinable(false).setParallelism(getParallelism(info)).name("PythonGroupReducePreStep")
.mapPartition(new PythonMapPartition(info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name);
.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name);
}
@SuppressWarnings("unchecked")
private DataSet applyGroupReduceOperation(SortedGrouping op1, PythonOperationInfo info) {
return op1.reduceGroup(new IdentityGroupReduce()).setCombinable(false).setParallelism(getParallelism(info)).name("PythonGroupReducePreStep")
.mapPartition(new PythonMapPartition(info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name);
.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name);
}
@SuppressWarnings("unchecked")
@@ -602,7 +615,7 @@ public class PythonPlanBinder {
if (info.usesUDF) {
sets.put(info.setID, createDefaultJoin(op1, op2, info.keys1, info.keys2, mode, getParallelism(info))
.mapPartition(new PythonMapPartition(info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
} else {
sets.put(info.setID, createDefaultJoin(op1, op2, info.keys1, info.keys2, mode, getParallelism(info)));
}
@@ -628,13 +641,13 @@ public class PythonPlanBinder {
@SuppressWarnings("unchecked")
private void createMapOperation(PythonOperationInfo info) {
DataSet op1 = (DataSet) sets.get(info.parentID);
sets.put(info.setID, op1.mapPartition(new PythonMapPartition(info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
sets.put(info.setID, op1.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
}
@SuppressWarnings("unchecked")
private void createMapPartitionOperation(PythonOperationInfo info) {
DataSet op1 = (DataSet) sets.get(info.parentID);
sets.put(info.setID, op1.mapPartition(new PythonMapPartition(info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
sets.put(info.setID, op1.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name));
}
private void createReduceOperation(PythonOperationInfo info) {
@@ -651,12 +664,12 @@ public class PythonPlanBinder {
@SuppressWarnings("unchecked")
private DataSet applyReduceOperation(DataSet op1, PythonOperationInfo info) {
return op1.reduceGroup(new IdentityGroupReduce()).setCombinable(false).setParallelism(getParallelism(info)).name("PythonReducePreStep")
.mapPartition(new PythonMapPartition(info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name);
.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name);
}
@SuppressWarnings("unchecked")
private DataSet applyReduceOperation(UnsortedGrouping op1, PythonOperationInfo info) {
return op1.reduceGroup(new IdentityGroupReduce()).setCombinable(false).setParallelism(getParallelism(info)).name("PythonReducePreStep")
.mapPartition(new PythonMapPartition(info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name);
.mapPartition(new PythonMapPartition(info.envID, info.setID, info.types)).setParallelism(getParallelism(info)).name(info.name);
}
}
@@ -34,9 +34,9 @@ public class PythonCoGroup<IN1, IN2, OUT> extends RichCoGroupFunction<IN1, IN2,
private final PythonStreamer<IN1, IN2, OUT> streamer;
private final transient TypeInformation<OUT> typeInformation;
public PythonCoGroup(int id, TypeInformation<OUT> typeInformation) {
public PythonCoGroup(int envID, int setID, TypeInformation<OUT> typeInformation) {
this.typeInformation = typeInformation;
streamer = new PythonStreamer<>(this, id, true);
streamer = new PythonStreamer<>(this, envID, setID, true);
}
/**
......
@@ -35,9 +35,9 @@ public class PythonMapPartition<IN, OUT> extends RichMapPartitionFunction<IN, OU
private final PythonStreamer<IN, IN, OUT> streamer;
private final transient TypeInformation<OUT> typeInformation;
public PythonMapPartition(int id, TypeInformation<OUT> typeInformation) {
public PythonMapPartition(int envId, int setId, TypeInformation<OUT> typeInformation) {
this.typeInformation = typeInformation;
streamer = new PythonStreamer<>(this, id, typeInformation instanceof PrimitiveArrayTypeInfo);
streamer = new PythonStreamer(this, envId, setId, typeInformation instanceof PrimitiveArrayTypeInfo);
}
/**
......
@@ -56,7 +56,8 @@ public class PythonStreamer<IN1, IN2, OUT> implements Serializable {
private static final int SIGNAL_ERROR = -2;
private static final byte SIGNAL_LAST = 32;
private final int id;
private final int envID;
private final int setID;
private final boolean usePython3;
private final String planArguments;
@@ -78,8 +79,9 @@ public class PythonStreamer<IN1, IN2, OUT> implements Serializable {
protected transient Thread outPrinter;
protected transient Thread errorPrinter;
public PythonStreamer(AbstractRichFunction function, int id, boolean usesByteArray) {
this.id = id;
public PythonStreamer(AbstractRichFunction function, int envID, int setID, boolean usesByteArray) {
this.envID = envID;
this.setID = setID;
this.usePython3 = PythonPlanBinder.usePython3;
planArguments = PythonPlanBinder.arguments.toString();
sender = new PythonSender();
@@ -99,8 +101,8 @@ public class PythonStreamer<IN1, IN2, OUT> implements Serializable {
}
private void startPython() throws IOException {
String outputFilePath = FLINK_TMP_DATA_DIR + "/" + id + this.function.getRuntimeContext().getIndexOfThisSubtask() + "output";
String inputFilePath = FLINK_TMP_DATA_DIR + "/" + id + this.function.getRuntimeContext().getIndexOfThisSubtask() + "input";
String outputFilePath = FLINK_TMP_DATA_DIR + "/" + envID + "_" + setID + this.function.getRuntimeContext().getIndexOfThisSubtask() + "output";
String inputFilePath = FLINK_TMP_DATA_DIR + "/" + envID + "_" + setID + this.function.getRuntimeContext().getIndexOfThisSubtask() + "input";
sender.open(inputFilePath);
receiver.open(outputFilePath);
@@ -136,8 +138,9 @@ public class PythonStreamer<IN1, IN2, OUT> implements Serializable {
OutputStream processOutput = process.getOutputStream();
processOutput.write("operator\n".getBytes(ConfigConstants.DEFAULT_CHARSET));
processOutput.write((envID + "\n").getBytes(ConfigConstants.DEFAULT_CHARSET));
processOutput.write((setID + "\n").getBytes(ConfigConstants.DEFAULT_CHARSET));
processOutput.write(("" + server.getLocalPort() + "\n").getBytes(ConfigConstants.DEFAULT_CHARSET));
processOutput.write((id + "\n").getBytes(ConfigConstants.DEFAULT_CHARSET));
processOutput.write((this.function.getRuntimeContext().getIndexOfThisSubtask() + "\n")
.getBytes(ConfigConstants.DEFAULT_CHARSET));
processOutput.write((inputFilePath + "\n").getBytes(ConfigConstants.DEFAULT_CHARSET));
......
@@ -54,19 +54,7 @@ public class PythonPlanStreamer {
}
public void open(String tmpPath, String args) throws IOException {
server = new ServerSocket(0);
server.setSoTimeout(50);
startPython(tmpPath, args);
while (true) {
try {
socket = server.accept();
break;
} catch (SocketTimeoutException ignored) {
checkPythonProcessHealth();
}
}
sender = new PythonPlanSender(socket.getOutputStream());
receiver = new PythonPlanReceiver(socket.getInputStream());
}
private void startPython(String tmpPath, String args) throws IOException {
@@ -82,11 +70,48 @@ public class PythonPlanStreamer {
new StreamPrinter(process.getInputStream()).start();
new StreamPrinter(process.getErrorStream()).start();
server = new ServerSocket(0);
server.setSoTimeout(50);
process.getOutputStream().write("plan\n".getBytes(ConfigConstants.DEFAULT_CHARSET));
process.getOutputStream().write((server.getLocalPort() + "\n").getBytes(ConfigConstants.DEFAULT_CHARSET));
process.getOutputStream().flush();
}
public boolean preparePlanMode() throws IOException {
try {
process.getOutputStream().write((server.getLocalPort() + "\n").getBytes(ConfigConstants.DEFAULT_CHARSET));
process.getOutputStream().flush();
} catch (IOException ignored) {
// the Python process most likely shut down in the meantime
return false;
}
while (true) {
try {
socket = server.accept();
sender = new PythonPlanSender(socket.getOutputStream());
receiver = new PythonPlanReceiver(socket.getInputStream());
return true;
} catch (SocketTimeoutException ignored) {
switch(checkPythonProcessHealth()) {
case RUNNING:
continue;
case STOPPED:
return false;
case FAILED:
throw new RuntimeException("Plan file caused an error. Check log-files for details.");
}
}
}
}
public void finishPlanMode() {
try {
socket.close();
} catch (IOException e) {
LOG.error("Failed to close socket.", e);
}
}
public void close() {
try {
process.exitValue();
@@ -95,22 +120,29 @@ public class PythonPlanStreamer {
process.destroy();
} finally {
try {
socket.close();
server.close();
} catch (IOException e) {
LOG.error("Failed to close socket.", e);
}
}
}
private void checkPythonProcessHealth() {
private ProcessState checkPythonProcessHealth() {
try {
int value = process.exitValue();
if (value != 0) {
throw new RuntimeException("Plan file caused an error. Check log-files for details.");
return ProcessState.FAILED;
} else {
throw new RuntimeException("Plan file exited prematurely without an error.");
return ProcessState.STOPPED;
}
} catch (IllegalThreadStateException ignored) { // process is still running
return ProcessState.RUNNING;
}
}
private enum ProcessState {
RUNNING,
FAILED,
STOPPED
}
}
@@ -27,17 +27,63 @@ import copy
import sys
from struct import pack
class EnvironmentContainer(object):
"""Keeps track of which ExecutionEnvironment is active."""
_environment_counter = 0
_environment_id_to_execute = None
_plan_mode = None
def create_environment(self):
"""Creates a new environment with a unique id."""
env = Environment(self, self._environment_counter)
self._environment_counter += 1
return env
def is_planning(self):
"""
Checks whether we are generating the plan or executing an operator.
:return: True, if the plan is generated, false otherwise
"""
if self._plan_mode is None:
mode = sys.stdin.readline().rstrip('\n')
if mode == "plan":
self._plan_mode = True
elif mode == "operator":
self._plan_mode = False
else:
raise ValueError("Invalid mode specified: " + mode)
return self._plan_mode
def should_execute(self, environment):
"""
Checks whether the given ExecutionEnvironment should run the contained plan.
:param environment: the ExecutionEnvironment to check
:return: True, if the environment should run the contained plan, false otherwise
"""
if self._environment_id_to_execute is None:
self._environment_id_to_execute = int(sys.stdin.readline().rstrip('\n'))
return environment._env_id == self._environment_id_to_execute
container = EnvironmentContainer()
def get_environment():
"""
Creates an execution environment that represents the context in which the program is currently executed.
:return: The execution environment of the context in which the program is executed.
"""
return Environment()
return container.create_environment()
class Environment(object):
def __init__(self):
def __init__(self, container, env_id):
# util
self._counter = 0
@@ -46,6 +92,9 @@ class Environment(object):
self._local_mode = False
self._retry = 0
self._container = container
self._env_id = env_id
#sets
self._sources = []
self._sets = []
@@ -166,9 +215,7 @@ class Environment(object):
self._local_mode = local
self._optimize_plan()
plan_mode = sys.stdin.readline().rstrip('\n') == "plan"
if plan_mode:
if self._container.is_planning():
port = int(sys.stdin.readline().rstrip('\n'))
self._connection = Connection.PureTCPConnection(port)
self._iterator = Iterator.PlanIterator(self._connection, self)
@@ -180,31 +227,34 @@ class Environment(object):
else:
import struct
operator = None
port = None
try:
port = int(sys.stdin.readline().rstrip('\n'))
id = int(sys.stdin.readline().rstrip('\n'))
subtask_index = int(sys.stdin.readline().rstrip('\n'))
input_path = sys.stdin.readline().rstrip('\n')
output_path = sys.stdin.readline().rstrip('\n')
used_set = None
operator = None
for set in self._sets:
if set.id == id:
used_set = set
operator = set.operator
operator._configure(input_path, output_path, port, self, used_set, subtask_index)
operator._go()
operator._close()
sys.stdout.flush()
sys.stderr.flush()
if self._container.should_execute(self):
id = int(sys.stdin.readline().rstrip('\n'))
port = int(sys.stdin.readline().rstrip('\n'))
subtask_index = int(sys.stdin.readline().rstrip('\n'))
input_path = sys.stdin.readline().rstrip('\n')
output_path = sys.stdin.readline().rstrip('\n')
used_set = None
operator = None
for set in self._sets:
if set.id == id:
used_set = set
operator = set.operator
operator._configure(input_path, output_path, port, self, used_set, subtask_index)
operator._go()
operator._close()
sys.stdout.flush()
sys.stderr.flush()
except:
sys.stdout.flush()
sys.stderr.flush()
if operator is not None:
operator._connection._socket.send(struct.pack(">i", -2))
else:
elif port is not None:
socket = SOCKET.socket(family=SOCKET.AF_INET, type=SOCKET.SOCK_STREAM)
socket.connect((SOCKET.gethostbyname("localhost"), port))
socket.send(struct.pack(">i", -2))
@@ -277,6 +327,7 @@ class Environment(object):
collect(("dop", self._dop))
collect(("mode", self._local_mode))
collect(("retry", self._retry))
collect(("id", self._env_id))
def _send_operations(self):
self._collector.collect(len(self._sources) + len(self._sets) + len(self._sinks) + len(self._broadcast))
......
# ###############################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from flink.plan.Environment import get_environment
from flink.functions.MapFunction import MapFunction
from flink.functions.CrossFunction import CrossFunction
from flink.functions.JoinFunction import JoinFunction
from flink.functions.CoGroupFunction import CoGroupFunction
from flink.functions.Aggregation import Max, Min, Sum
from utils import Verify, Verify2, Id
# Test multiple jobs in one Python plan file
if __name__ == "__main__":
env = get_environment()
env.set_parallelism(1)
d1 = env.from_elements(1, 6, 12)
d1 \
.first(1) \
.map_partition(Verify([1], "First with multiple jobs in one Python plan file")).output()
env.execute(local=True)
env2 = get_environment()
env2.set_parallelism(1)
d2 = env2.from_elements(1, 1, 12)
d2 \
.map(lambda x: x * 2) \
.map_partition(Verify([2, 2, 24], "Lambda Map with multiple jobs in one Python plan file")).output()
env2.execute(local=True)