# Design Doc: Distributed Training Architecture
## Abstract
PaddlePaddle v0.10.0 uses the "trainer-parameter server"
architecture: we run multiple replicated instances of trainers (each
running the same user-written code) and parameter servers for
distributed training. This architecture has served us well, but it has
some limitations:
1. Need to write special code to handle tasks which should only be run
by a single trainer. E.g., initializing model and saving model.
2. Model parallelism is hard: the user needs to write if-else branches
conditioned on the trainer ID to partition the model onto the trainers,
and manually write the inter-model-shard communication code (see the
sketch after this list).
3. The user can not directly specify the parameter update rule: they
need to modify the parameter server C++ code and compile a new binary,
which requires a lot of extra effort from researchers. Besides, the
training job submission program may not allow running arbitrary binaries.
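For example, manual model parallelism in the v0.10.0 architecture would look roughly like the following (hypothetical pseudocode, not a real v0.10.0 API):

```python
# Illustrative only: trainer-ID branching plus hand-written communication.
trainer_id = get_trainer_id()  # hypothetical helper

if trainer_id == 0:
    # The first model shard runs on trainer 0.
    hidden = fc_layer(input=img, size=200)
    send_to(trainer=1, data=hidden)   # hand-written inter-shard send
else:
    # The second model shard runs on trainer 1.
    hidden = recv_from(trainer=0)     # hand-written inter-shard recv
    prediction = fc_layer(input=hidden, size=10)
```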
This design doc discusses PaddlePaddle's new distributed training
architecture that addresses the above limitations.
## Analysis
We will assume the user writes the trainer program in Python; the same
analysis holds if the trainer program is written in C++.
### Limitation 1
If we look at the Python code that the user writes, there are two
kinds of functionalities:
- The training logic, such as loading / saving the model and printing logs.
- The neural network definition, such as the definition of the data
  layer, the fully connected layer, the cost function and the
  optimizer.
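For example, the two kinds of functionality are typically interleaved in one trainer program (a minimal sketch in the v0.10.0-style API; `save_model` is a hypothetical helper):

```python
import paddle.v2 as paddle

paddle.init(use_gpu=False, trainer_count=1)

# Neural network definition: safe to replicate on every node.
img = paddle.layer.data(name="img",
                        type=paddle.data_type.dense_vector(784))
prediction = paddle.layer.fc(input=img, size=10,
                             act=paddle.activation.Softmax())

# Training logic: should run exactly once per job, not once per replica.
parameters = paddle.parameters.create(prediction)  # initialize the model
# ... run training passes ...
save_model(parameters)  # hypothetical helper that saves the model
print("pass finished")  # log printing
```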
When training with PaddlePaddle v0.10.0 distributedly, multiple
replicated Python instances run on different nodes: both the
training logic and the neural network computation are replicated.

The tasks that should only run once all belong to the training logic.
If we replicate only the neural network computation and do **not**
replicate the training logic, this limitation is solved.
### Limitation 2
Model parallelism means running a single model on multiple nodes by
partitioning the model onto different nodes and managing the
inter-model-shard communications.
PaddlePaddle should be able to modify the neural network computation
definition to support model parallelism automatically. However, the
computation is only specified in Python code, and PaddlePaddle cannot
modify Python code.
Just like a compiler uses an intermediate representation (IR) so that
the programmer does not need to manually optimize their code in most
cases (the compiler optimizes the IR instead), we can have our own IR,
called [Program](../program.md). PaddlePaddle can then support model
parallelism by converting the IR, so the user no longer needs to do the
partitioning manually in Python, as sketched below:
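A hedged sketch of what such an IR conversion could look like (all helper names here are illustrative, not the converter's real API):

```python
# Illustrative sketch only: partition one Program onto several devices,
# inserting send/recv pairs at the cut edges. Helper names are hypothetical.
def partition_program(program, placement):
    parts = {dev: Program() for dev in set(placement.values())}
    for op in program.ops:
        dev = placement[op.name]  # device chosen for this OP
        parts[dev].append_op(op)
        for out in op.outputs:
            for consumer in program.consumers(out):
                if placement[consumer.name] != dev:
                    # The producing OP and the consuming OP live on
                    # different devices: add a send/recv pair.
                    parts[dev].append_op(make_send_op(out))
                    parts[placement[consumer.name]].append_op(make_recv_op(out))
    return parts
```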
### Limitation 3
The user can not directly specify the parameter update rule for the
parameter server because the previous implementation hard codes the
parameter server to run a fixed optimization algorithm over vectors,
selected by configuration. The user can not specify the parameter
server's computation layer by layer.

This could be fixed by making the parameter server run its own IR,
generated according to the trainer's variable (tensor, selected rows)
definitions and the same computation definition as the trainer. For a
detailed explanation, please see
[Design Doc: Operation Graph-Based Parameter Server](./parameter_server.md)
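For instance, with its own IR the parameter server's update rule is just ordinary OPs appended per parameter, so the user can choose it freely (a hedged sketch; the OP and attribute names are illustrative):

```python
# Illustrative sketch: build the parameter server's IR by appending one
# optimizer OP per parameter; "sgd" could be swapped for any update rule.
pserver_program = Program()
for param, grad in params_and_grads:
    pserver_program.append_op(
        type="sgd",
        inputs={"Param": param, "Grad": grad, "LearningRate": lr},
        outputs={"ParamOut": param})
```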
## Distributed Training Architecture
The new distributed training architecture addresses the above
limitations. It includes three major components: *PaddlePaddle Python*,
*PaddlePaddle converter* and *PaddlePaddle runtime*:
### PaddlePaddle Python
PaddlePaddle Python is the Python library that the user's trainer
program invokes to build the neural network topology, start training, etc.:
```python
paddle.init()
input = paddle.op.recordIO("/home/data/mnist.recordio")  # file stored on the cluster
img, label = input[0], input[1]
hidden = paddle.layer.fc(input=img, size=200, act=paddle.activation.Tanh())
prediction = paddle.layer.fc(input=hidden, size=10, act=paddle.activation.Softmax())
cost = paddle.layer.classification_cost(input=prediction, label=label)
optimizer = paddle.optimizer.SGD(learning_rate=0.01)
opts = optimizer.minimize(cost)

exe = RemoteExecutor(num_trainer=3, num_ps=2, GPU_per_trainer=2, sync_batches=1)
# this initializes variable data on both the parameter servers and the trainers
exe.run(framework.default_startup_program())
exe.sync()

for i in range(1000):
    # feed data
    ...
    cost_val = exe.run(framework.default_main_program(),
                       fetch_list=[cost])
    print(cost_val)
```
The code above is a typical Python trainer program: the neural network
topology is built using helper functions such as `paddle.layer.fc`, and
the training is done by calling `Executor.run` iteratively.
#### RemoteExecutor
`RemoteExecutor.run` sends the IR to the PaddlePaddle cluster for
execution. You can also use the `fetch_list` parameter to interactively
fetch variables back locally, e.g. for log printing.
The Python `RemoteExecutor` is derived from `Executor` class.
For more information about `RemoteExecutor`, please
see [Design Doc: RemoteExecutor](./remote_executor.md).
The `RemoteExecutor.run` interface definition is:
```python
run(self,
    program=None,
    feed=None,
    fetch_list=None,
    feed_var_name='feed',
    fetch_var_name='fetch',
    job_desc=JobDesc(
        jobname,
        num_trainer,
        num_pserver,
        cpu_per_trainer,
        gpu_per_trainer,
        mem_per_trainer,
        cpu_per_pserver,
        mem_per_pserver
    ))
```
The `JobDesc` object describes the resource specification of the
distributed job to run on the cluster.
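For example, a call with an explicit resource specification might look like this (a sketch based on the interface above; the value formats are illustrative):

```python
exe.run(framework.default_main_program(),
        fetch_list=[cost],
        job_desc=JobDesc(jobname="mnist-demo",
                         num_trainer=3,
                         num_pserver=2,
                         cpu_per_trainer=4,
                         gpu_per_trainer=2,
                         mem_per_trainer="8Gi",
                         cpu_per_pserver=2,
                         mem_per_pserver="4Gi"))
```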
By default, `Executor.run` starts a PaddlePaddle Cloud
[TrainingJob](https://github.com/PaddlePaddle/cloud/blob/develop/doc/autoscale/README.md#training-job-resource),
or you can run each component yourself, using your own deployment
method:
- Data Parallelism
```python
import os

if os.getenv('PLACE_PSERVER'):
    exe.run_pserver()
elif os.getenv('PLACE_TRAINER'):
    exe.run_trainer()
```
- Model Parallelism
```python
for part in exe.get_parallel_parts():
    exe.run_part(part)
```
#### Program and Executor
As mentioned above, the implementation of IR is [Program](../program.md).
[Executor](../executor.md) converts and parses the IR to a preferred
graph for final execution. For local training you generally use
`Executor` to run the graph locally. For any kind of distributed
training, you can use `RemoteExecutor` to specify desired distributed
training method with some optional arguments.
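In other words, switching between local and distributed training is mostly a matter of swapping the executor (a hedged sketch; the `RemoteExecutor` arguments follow the earlier example):

```python
# Local training: Executor runs the Program on the local devices.
exe = Executor(place)  # `place` selects a local device, e.g. CPU or GPU
exe.run(framework.default_main_program())

# Distributed training: RemoteExecutor ships the same Program to the cluster.
exe = RemoteExecutor(num_trainer=3, num_ps=2, GPU_per_trainer=2, sync_batches=1)
exe.run(framework.default_main_program())
```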
### PaddlePaddle Converter
The PaddlePaddle converter automatically converts the IR in the request
(IR and evaluation inputs/targets) from PaddlePaddle Python into new,
partitioned IRs, and dispatches the new IRs and evaluation
inputs/targets to different PaddlePaddle runtimes. Below are the steps:
1. Add `feed` OP that feeds the eval inputs, and `fetch` OP that
fetches the eval targets to the IR.
1. Extract a new computation (sub)graph with `feed` and `fetch` OP as
the boundary. The runtime does not need to run the OP that is not
dependent on the `fetch` OP.
1. Optimize the computation graph.
1. Place the OPs in the graph onto different devices on different
PaddlePaddle runtime according to a placement algorithm and device
constraint specified by the user.
1. Partition the graph according to runtime boundaries and add `send` /
`recv` OP pair on the runtime boundaries.
1. Dispatch the partitioned graph to different PaddlePaddle runtimes.
1. The PaddlePaddle runtimes with the `fetch` OP report the evaluation
   results back to the converter, and the converter reports the
   evaluation results back to PaddlePaddle Python.
The output IRs will be cached to optimize the conversion latency.
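Putting the steps together, the converter pipeline could be summarized as follows (a hedged sketch; every helper name is illustrative):

```python
# Illustrative summary of the conversion steps; helper names are hypothetical.
def convert(ir, feeds, fetches, placement):
    ir = add_feed_fetch_ops(ir, feeds, fetches)   # step 1
    ir = extract_subgraph(ir, feeds, fetches)     # step 2: drop OPs the
                                                  # fetch targets don't need
    ir = optimize(ir)                             # step 3
    placed = place_ops(ir, placement)             # step 4
    parts = partition_with_send_recv(placed)      # step 5
    return dispatch(parts)                        # steps 6-7: run, then report
                                                  # fetch results back
```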
#### Placement Algorithm
Our first implementation will only support "trainer-parameter server"
placement: the parameters, initializers, and optimizers are placed on
the PaddlePaddle runtimes with the parameter server role. And
everything else will be placed on the PaddlePaddle runtimes with the
trainer role. This has the same functionality as the
"trainer-parameter server" architecture of PaddlePaddle v0.10.0, but
is more general and flexible.
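The first placement rule is simple enough to state in a few lines (a hedged sketch; the predicates are illustrative):

```python
# Illustrative sketch of the "trainer-parameter server" placement rule.
def place(op):
    # Parameters, their initializers, and their optimizers go to the
    # parameter server role; every other OP goes to the trainer role.
    if is_parameter_op(op) or is_initializer_op(op) or is_optimizer_op(op):
        return "pserver"
    return "trainer"
```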
In the future, we will implement the general placement algorithm,
which makes placements according to the input IR, and a model of
device computation time and device communication time. Model
parallelism requires the general placement algorithm.
### Local Training Architecture
The local training architecture will be the same as the distributed
training architecture; the difference is that everything runs locally,
and there is just one PaddlePaddle runtime.
### Training Data
In PaddlePaddle v0.10.0, training data is typically read
with a [data reader](../reader/README.md) from Python. This approach is
no longer efficient for distributed training, since the Python process
no longer runs on the same node as the trainer processes: the Python
reader would need to read from the distributed filesystem (assuming it
has access) and send the data to the trainers, doubling the network
traffic.
When doing distributed training, the user can still use the Python data
reader: the training data is sent with `Executor.run`. However, this
should be used for debugging purposes only. Users are encouraged to use
the read-data OPs instead.
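A hedged comparison of the two approaches (the reader OP mirrors `paddle.op.recordIO` from the earlier example; `python_reader` is a hypothetical Python-side reader):

```python
# Debugging only: data is read in Python and shipped with every Executor.run.
for batch_img, batch_label in python_reader():
    exe.run(framework.default_main_program(),
            feed={"img": batch_img, "label": batch_label})

# Preferred: a read-data OP inside the IR; each trainer reads the
# distributed filesystem directly, with no extra hop through Python.
input = paddle.op.recordIO("/home/data/mnist.recordio")
img, label = input[0], input[1]
```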