提交 4c33a100 编写于 作者: Y Yi Wang

Update distributed lookup table design doc

上级 267ffc28
## Distributed lookup table design
## Design Doc: Distributed Lookup Table Operator
A lookup table operator in PaddlePaddle where the table could be out
of the memory of a computer.
## Background
Embedding is a popular technique used in neural network to support applications such as search engines, advertising systems, and recommendation systems.
A lookup table operator is well-used in deep learning for learning the
representation, or the
[*embedding*](http://www.cs.toronto.edu/~fritz/absps/ieee-lre.pdf), of
symbols.
### The Forward Algorithm
The forward algorithm of the lookup table is a multiplication of the
input vector x and the lookup table matrix W:
Embeddings are stored in a lookup table (or hash table), that given a word id, returns the embedding (which is an array of numbers).
$$y = x * W$$
It works as below:
When x is a sparse vector of symbols, the above multiplication
simplifies into looking up rows in W that correspond to symbols in x,
denoted by W(x). Please be aware that W could be huge and out of the
memory, so we'd need a distributed storage service, which supports the
lookup of rows.
The following figure illustrates the multiplication of x with two
non-zero elements, or say, two symbols, and a lookup table W:
![lookup table](./lookup_table.png)
## Problem
The column number of the lookup_table is proportional to the range of id. In internet scale, the range of id may be very large, say 100000000000, if the size of an embedding value is 40 Byte, then the whole memory the `lookup_table` will use can be 3725.29GB:
```shell
3725.29GB = 100000000000 * 40 / 1024.0 / 1024.0 / 1024.0
```
This cannot be stored in the memory of a single machine, so we need to add a distributed lookup table that stores it in a cluster and provide the interface to get value and set value.
## Training Process
The training process with lookup table on a single machine is as follows:
![lookup table training](./lookup_table_training.png)
1. In forward pass. `lookup_table_op` convert ids into a dense tensor `ids_embedding`. `ids_embedding` will be used by the following operators.
```
lookup_table_op(lookup_table, ids) -> ids_embedding
```
1. In backward pass. `lookup_table_grad_op` convert dense tensor `ids_embedding_grad` into a tensor with id information.
```
lookup_table_grad_op(lookup_table, ids_embedding_grad) -> lookup_table_grad
```
1. In optimization pass. optimize op apply gradient to `lookup_table`.
```
optimize_op(lookup_table, lookup_table_grad) -> lookup_table
```
All the operators above access the `lookup_table` directly in memory. If we change `lookup_table` into a distributed service, all the op that will use lookup_table need to access it using some RPC calls.
## TODO
1. Implement `distributed lookup table`, with service part and client part. The client should provide four interfaces:
- `Pull(ids) -> embedding_values` pull embedding_values according to ids.
- `Push(grad, update_method)` push `grad` to the distributed lookup table and update it to the parameter using `update_method `, this interface use is asynchronous.
- `Save()` save the model to a persistent file system, such as HDFS.
- `Load()` load the model from a persistent file system, such as HDFS.
The details will be proposed in another PR.
1. Design and implement `lookup_table_op` and `lookup_table_grad_op ` with distributed lookup table client.
1. Implement the Python wrapper for above ops, users can choose and config to use these ops.
1. The distributed Fluid should support this `distributed lookup table service` on kubernetes.
1. Implement a `distributed transpiler` that can change the program into a distributed one which will use the `distributed lookup table service`.
## Things need to be discussed
In the above design, the parameter update is done within `distributed lookup table service`, the interface is `Push(grad, update_method)`, this is different than the current design of PaddlePaddle Fluid. Currently, parameter update is done by Operators. How should we impelement these update_method?
### The Backward Algorithm
The backward algorithm computes W'(x) using W(x). W'(x) has the same
scale of size as W(x) and is much smaller than W.
To optimize W given W', we can do simple SGD update:
$$W = f(W') = \lambda * W'$$
or some more sophisticated algorithms that rely on both W' and W:
$$W = f(W, W')$$
The following figure illustrates the backward pass of the lookup
operator: ![lookup table training](./lookup_table_training.png)
## Distributed Storage Service
The forward algorithm requires a distributed storage service for W.
The backward algorithm prefers that the storage system can apply the
optimization algorithm on W. The following two sections describe two
solutions -- the former doesn't require that the storage service can
do optimization, the latter does.
### Storage Service Doesn't Optimize
In this design, we use highly-optimized distributed storage, e.g.,
memcached, as the storage service, and we run the optimization
algorithm on parameter servers of PaddlePaddle. The following figure
illustrates the training process.
<img src='https://g.gravizo.com/svg?
digraph G {
rankdir="LR";
subgraph cluster1 {
P1 [label="pserver 1"];
P2 [label="pserver 2"];
T1 [label="trainer 1"];
T2 [label="trainer 2"];
T3 [label="trainer 3"];
}
KV [label="memcached"];
T1 -> P1;
T1 -> P2;
T2 -> P1;
T2 -> P2;
T3 -> P1;
T3 -> P2;
P1 -> KV [color=gray, weight=0.1];
KV -> P1 [color=gray, weight=0.1];
P2 -> KV [color=gray, weight=0.1];
KV -> P2 [color=gray, weight=0.1];
KV -> T1 [color=gray, weight=0.1];
KV -> T2 [color=gray, weight=0.1];
KV -> T3 [color=gray, weight=0.1];
}
'/>
Each trainer runs the forward and backward passes using their local
data:
1. In the forward pass, when a trainer runs the forward algorithm of a
lookup operator, it retrieves W(x) from the storage service.
1. The trainer computes W'(x) in the backward pass using W(x).
During the global update process:
1. Each trainer uploads its W'(x) to parameter servers.
1. The parameter server runs the optimization algorithm, e.g., the
Adam optimization algorithm, which requires that
1. The parameter server retrieves W(x) from memcached, and
1. The parameter server pushes $\Delta W(x)=f(W(x), lambda \sum_j
W'(x))$ to memcached, where $f$ denotes the optimization
algorithm.
### Storage Service Does Optimize
This design is very similar to the above one, except that the
optimization algorithm $f$ runs on the storage service.
- Pro: parameter servers do not retrieve W(x) from the storage
service, thus saves half network communication.
- Con: the storage service needs to be able to run the optimization
algorithm.
## Conclusion
Let us do the "storage service does not optimize" solution first, as a
baseline at least, because it is easier to use a well-optimized
distributed storage service like memcached. We can do the "storage
service does optimize" solution later or at the same time, which, if
implemented carefully, should have better performance than the former.
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册