# Design Doc: Distributed Lookup Table Operator

A distributed lookup table operator in PaddlePaddle, where the table could be
too large to fit in the memory of a single machine.

## Background

A lookup table operator is widely used in deep learning for learning the
representation, or the
[*embedding*](http://www.cs.toronto.edu/~fritz/absps/ieee-lre.pdf), of
symbols.

### The Forward Algorithm

The forward algorithm of the lookup table is a multiplication of the
input vector x and the lookup table matrix W:

$$y = x * W$$

When x is a sparse vector of symbols, the above multiplication
simplifies into looking up the rows of W that correspond to the symbols
in x, denoted by W(x).  Please be aware that W could be too large to fit
in memory, so we need a distributed storage service that supports
looking up rows.

The following figure illustrates the multiplication of x, with two
non-zero elements (that is, two symbols), and a lookup table W:

![lookup table](./src/lookup_table.png)
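
A minimal NumPy sketch of this reduction (illustrative only, not the Paddle kernel): with a sparse, one-hot x, the dense product x * W and the row lookup W(x) produce the same result.

```python
import numpy as np

vocab_size, emb_dim = 10, 4
W = np.random.rand(vocab_size, emb_dim).astype(np.float32)

ids = np.array([2, 7])                 # the two non-zero symbols in x
y = W[ids]                             # W(x): gather rows instead of a dense matmul

# Equivalent dense form, for comparison only.
x = np.zeros((len(ids), vocab_size), dtype=np.float32)
x[np.arange(len(ids)), ids] = 1.0
assert np.allclose(y, x @ W)
```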

### The Backward Algorithm

The backward algorithm computes W'(x) using W(x).  W'(x) has the same
size as W(x) and is much smaller than W.

To optimize W given W', we can do simple SGD update:

$$W = f(W') = \lambda * W'$$

or some more sophisticated algorithms that rely on both W' and W:

$$W = f(W, W')$$

The following figure illustrates the backward pass of the lookup
operator:

![lookup table training](./src/lookup_table_training.png)

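As an illustration, the sparse update can be sketched as follows (a minimal NumPy example using the familiar SGD form W ← W − λ·W' as one instance of f(W, W'); only the rows in W(x) are touched):

```python
import numpy as np

vocab_size, emb_dim, lam = 10, 4, 0.01
W = np.random.rand(vocab_size, emb_dim).astype(np.float32)

ids = np.array([2, 7])              # rows selected by the forward lookup, W(x)
grad_Wx = np.random.rand(len(ids), emb_dim).astype(np.float32)   # W'(x)

# Update only the looked-up rows; the rest of W is untouched.
# (If ids contained duplicates, np.subtract.at(W, ids, lam * grad_Wx) would be needed.)
W[ids] -= lam * grad_Wx
```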
## Distributed Lookup Table
### Problem 1: The lookup table may be very large.

In applications such as search engines and recommendation systems, the number of feature ids may be very large, say 100,000,000,000. For an embedding of width 8 stored as 4-byte floats, the total size of the table is:

```
100,000,000,000 * 8 * 4.0 = 2980.23 GB
```

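A quick way to verify the figure above (a throwaway Python calculation, assuming 4-byte float32 values and GB meaning 1024^3 bytes):

```python
num_ids, emb_dim, bytes_per_float = 100_000_000_000, 8, 4
total_gb = num_ids * emb_dim * bytes_per_float / 1024 ** 3
print(f"{total_gb:.2f} GB")   # -> 2980.23 GB
```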
### Solution: Distributed storage

1. Paddle uses SelectedRows as the storage format for the lookup table. The table parameter is split across multiple machines according to a hash of the feature id, and the input data is split and sent to the corresponding machines to prefetch the parameter rows.

1. For ordinary parameters, the trainer fetches the whole parameter for training. The big lookup table cannot be stored on a single trainer, but since the input features are very sparse, only a few rows are needed for each batch, so we use `prefetch_op` to prefetch only the rows the trainer needs (see the sketch after this list).

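A minimal sketch of the id-splitting idea, assuming a simple modulo hash and the helper name `split_ids` (the real work is done by `split_ids_op` inside Paddle):

```python
def split_ids(ids, num_pservers):
    """Group ids by shard, using shard = id % num_pservers as the hash."""
    shards = [[] for _ in range(num_pservers)]
    for i in ids:
        shards[i % num_pservers].append(i)
    return shards

# Ids 5, 8, 2, 8 hash to shard 2 and id 13 to shard 1, so the prefetch
# requests (and later the gradients) for them go to those machines.
print(split_ids([5, 8, 13, 2, 8], num_pservers=3))   # [[], [13], [5, 8, 2, 8]]
```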
### Problem 2. The ids in the lookup table are not known before training.

The feature id is computed by a hash function. Because the feature data source is so large, we cannot enumerate all the ids before training, so we cannot initialize the table in advance.
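
For illustration only (MD5 is an assumption here, not Paddle's hash), mapping a raw feature to an id might look like this; the id space is far too large to enumerate up front:

```python
import hashlib

def feature_to_id(feature: str, id_space: int = 2 ** 63) -> int:
    digest = hashlib.md5(feature.encode("utf-8")).hexdigest()
    return int(digest, 16) % id_space

print(feature_to_id("user_12345&query=running+shoes"))
```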

### Solution: Id auto growth

At the beginning of training, Paddle only allocates memory for the lookup table on the parameter-server side; the ids and their values are not initialized. During training, when a parameter server receives an id, it returns the existing parameter row if the id is already in the lookup table; if the id does not exist, Paddle adds it to the table and initializes its value.
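
A minimal sketch of the parameter-server side of this behavior (the class name and initializer are assumptions, not Paddle code):

```python
import numpy as np

class AutoGrowLookupTable:
    """Lookup table that creates and initializes a row the first time an id is seen."""

    def __init__(self, emb_dim, init_scale=0.01, seed=0):
        self.emb_dim = emb_dim
        self.init_scale = init_scale
        self.rows = {}                          # id -> np.ndarray of shape (emb_dim,)
        self.rng = np.random.default_rng(seed)

    def lookup(self, ids):
        out = np.empty((len(ids), self.emb_dim), dtype=np.float32)
        for i, id_ in enumerate(ids):
            if id_ not in self.rows:            # unseen id: allocate and initialize it
                self.rows[id_] = (self.init_scale *
                                  self.rng.standard_normal(self.emb_dim)).astype(np.float32)
            out[i] = self.rows[id_]
        return out

table = AutoGrowLookupTable(emb_dim=8)
print(table.lookup([42, 7, 42]).shape)          # (3, 8); the row for id 42 is created once
```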


## Architecture
The whole architecture of the distributed lookup table is as below:

### Training steps:
1. Read a batch of data; the data consists of feature ids.
1. The input ids are split by `split_ids_op` using the same hash function as the lookup table.
1. `prefetch_op` uses the split result to prefetch the needed parameter rows from the lookup table.
1. Run forward-backward to get the gradient of the lookup table.
1. `split_ids_op` splits the gradient, which is then sent to the parameter servers by `send_op`.
1. The parameter servers update the table with the received gradients (an end-to-end sketch follows the figure below).

![distribute lookup table](./src/distributed_lookup_table.jpeg)
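
The steps above can be condensed into the following self-contained sketch. All names here are illustrative stand-ins for `split_ids_op`, `prefetch_op`, and `send_op`, and the gradient is faked, since the model itself is out of scope:

```python
import numpy as np

EMB_DIM, NUM_PS, LR = 4, 2, 0.01
pservers = [dict() for _ in range(NUM_PS)]      # each shard: id -> embedding row

def shard_of(id_):                              # the shared hash, here a simple modulo
    return id_ % NUM_PS

def prefetch(ids):                              # steps 2-3: split ids, fetch rows
    rows = {}
    for id_ in ids:
        table = pservers[shard_of(id_)]
        if id_ not in table:                    # id auto growth on first access
            table[id_] = 0.01 * np.random.rand(EMB_DIM).astype(np.float32)
        rows[id_] = table[id_].copy()
    return rows

def send_grads(grads):                          # steps 5-6: sparse update on the owner shard
    for id_, g in grads.items():
        pservers[shard_of(id_)][id_] -= LR * g

batch_ids = [3, 8, 3, 5]                        # step 1: a batch of feature ids
rows = prefetch(set(batch_ids))
fake_grads = {id_: np.ones(EMB_DIM, dtype=np.float32) for id_ in rows}   # stand-in for step 4
send_grads(fake_grads)
```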