# Design Doc: Distributed Lookup Table Operator

This document describes a lookup table operator in PaddlePaddle whose
table could be too large to fit in the memory of a single computer.

## Background

A lookup table operator is widely used in deep learning for learning the
representation, or the
[*embedding*](http://www.cs.toronto.edu/~fritz/absps/ieee-lre.pdf), of
symbols.

### The Forward Algorithm

The forward algorithm of the lookup table is a multiplication of the
input vector x and the lookup table matrix W:

$$y = x * W$$

When x is a sparse vector of symbols, the above multiplication
simplifies into looking up the rows of W that correspond to the symbols
in x, denoted by W(x).  Because W could be huge and may not fit in
memory, we need a distributed storage service that supports looking up
rows.
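
To make the simplification concrete, here is a minimal pure-Python sketch (no PaddlePaddle API involved; `dense_forward`, `lookup_rows`, and `sparse_forward` are hypothetical names). It shows that for a 0/1 vector x, the dense product x * W equals the sum of the rows W(x) fetched by id, while the lookup touches only as many rows as x has symbols.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def dense_forward(x: Sequence[float], W: Matrix) -> List[float]:
    """y = x * W, treating x as a dense row vector of length len(W)."""
    width = len(W[0])
    return [sum(x[i] * W[i][j] for i in range(len(W))) for j in range(width)]


def lookup_rows(ids: Sequence[int], W: Matrix) -> Matrix:
    """W(x): fetch only the rows of W named by the symbol ids in x."""
    return [list(W[i]) for i in ids]


def sparse_forward(ids: Sequence[int], W: Matrix) -> List[float]:
    """Same result as dense_forward for a 0/1 vector x, reading only len(ids) rows."""
    rows = lookup_rows(ids, W)
    return [sum(col) for col in zip(*rows)]


W = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]  # 4 symbols, embedding width 2
x = [0, 1, 0, 1]                                        # symbols 1 and 3
print(dense_forward(x, W))        # [8.0, 10.0]
print(sparse_forward([1, 3], W))  # [8.0, 10.0]
```

With a table of millions of rows, `dense_forward` would scan every row, while `sparse_forward` reads two; that gap is what makes a row-lookup storage service sufficient.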

The following figure illustrates the multiplication of x, which has two
non-zero elements (that is, two symbols), with a lookup table W:

![lookup table](./src/lookup_table.png)

### The Backward Algorithm

The backward algorithm computes W'(x) using W(x).  W'(x) has the same
size as W(x) and is therefore much smaller than W.

To optimize W given W', we can apply a simple SGD update with learning
rate $\lambda$:

$$W = W - \lambda W'$$

or a more sophisticated algorithm that relies on both W' and W:

$$W = f(W, W')$$
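
A minimal sketch of the SGD case, assuming W is held as a list of rows and W'(x) arrives as one gradient row per id in x (`sparse_sgd_update` is a hypothetical name). Only the rows named by x change, which is why W'(x), rather than a full-size gradient of W, is all that has to be computed and shipped.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def sparse_sgd_update(W: Matrix, ids: Sequence[int], grad_rows: Matrix, lr: float) -> None:
    """In place: W[i] = W[i] - lr * g for each (id i, gradient row g) pair; other rows untouched.

    A repeated id is updated once per occurrence, which for SGD equals
    applying the sum of its gradient rows.
    """
    for row_id, grad in zip(ids, grad_rows):
        W[row_id] = [w - lr * g for w, g in zip(W[row_id], grad)]


W = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
sparse_sgd_update(W, ids=[2], grad_rows=[[0.5, -0.5]], lr=0.1)
print(W)  # rows 0 and 1 unchanged; row 2 moved against its gradient
```

A more sophisticated $f(W, W')$ would replace only the inner expression; the row-selective loop stays the same.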

The following figure illustrates the backward pass of the lookup
operator:

![lookup table training](./src/lookup_table_training.png)

## Distributed Storage Service

The forward algorithm requires a distributed storage service for W.
The backward algorithm would preferably use a storage system that can
apply the optimization algorithm to W.  The following two sections
describe two solutions: the first does not require the storage service
to run the optimization, and the second does.

### Storage Service Doesn't Optimize

In this design, we use a highly optimized distributed storage system,
e.g., memcached, as the storage service, and we run the optimization
algorithm on PaddlePaddle's parameter servers.  The following figure
illustrates the training process.

<!--
Note: please update the following URL when updating this digraph.
<img src='https://g.gravizo.com/svg?
digraph G {
  rankdir="LR";
  subgraph cluster1 {
  P1 [label="pserver 1"];
  P2 [label="pserver 2"];
  T1 [label="trainer 1"];
  T2 [label="trainer 2"];
  T3 [label="trainer 3"];
  }
  KV [label="memcached"];
  T1 -> P1;
  T1 -> P2;
  T2 -> P1;
  T2 -> P2;
  T3 -> P1;
  T3 -> P2;
  P1 -> KV [color=gray, weight=0.1];
  KV -> P1 [color=gray, weight=0.1];
  P2 -> KV [color=gray, weight=0.1];
  KV -> P2 [color=gray, weight=0.1];
  KV -> T1 [color=gray, weight=0.1];
  KV -> T2 [color=gray, weight=0.1];
  KV -> T3 [color=gray, weight=0.1];
}
)
'/>
-->

<img src='https://g.gravizo.com/svg?%20digraph%20G%20{%20rankdir=%22LR%22;%20subgraph%20cluster1%20{%20P1%20[label=%22pserver%201%22];%20P2%20[label=%22pserver%202%22];%20T1%20[label=%22trainer%201%22];%20T2%20[label=%22trainer%202%22];%20T3%20[label=%22trainer%203%22];%20}%20KV%20[label=%22memcached%22];%20T1%20-%3E%20P1;%20T1%20-%3E%20P2;%20T2%20-%3E%20P1;%20T2%20-%3E%20P2;%20T3%20-%3E%20P1;%20T3%20-%3E%20P2;%20P1%20-%3E%20KV%20[color=gray,%20weight=0.1];%20KV%20-%3E%20P1%20[color=gray,%20weight=0.1];%20P2%20-%3E%20KV%20[color=gray,%20weight=0.1];%20KV%20-%3E%20P2%20[color=gray,%20weight=0.1];%20KV%20-%3E%20T1%20[color=gray,%20weight=0.1];%20KV%20-%3E%20T2%20[color=gray,%20weight=0.1];%20KV%20-%3E%20T3%20[color=gray,%20weight=0.1];%20}'/>

Each trainer runs the forward and backward passes using its local
data:

1. In the forward pass, when a trainer runs the forward algorithm of a
   lookup operator, it retrieves W(x) from the storage service.
1. The trainer computes W'(x) in the backward pass using W(x).
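
The two trainer-side steps can be sketched as follows, assuming a plain dict stands in for memcached and assuming, purely for illustration, a squared loss $\frac{1}{2}\lVert \sum_{i \in x} W_i - t \rVert^2$ against a target $t$. `fetch_rows` and `forward_backward` are hypothetical names, not PaddlePaddle APIs.

```python
from typing import Dict, List, Sequence

Row = List[float]


def fetch_rows(kv: Dict[int, Row], ids: Sequence[int]) -> Dict[int, Row]:
    """Forward pass: retrieve W(x) from the storage service (a dict stands in for memcached)."""
    return {i: list(kv[i]) for i in set(ids)}


def forward_backward(kv: Dict[int, Row], ids: Sequence[int], target: Row) -> Dict[int, Row]:
    """Return W'(x) for the illustrative loss 0.5 * ||sum_i W[i] - target||^2."""
    rows = fetch_rows(kv, ids)  # only the rows named by x are transferred
    y = [sum(rows[i][k] for i in ids) for k in range(len(target))]
    residual = [yk - tk for yk, tk in zip(y, target)]
    grads: Dict[int, Row] = {}
    for i in ids:
        # d loss / d W[i] is the residual, accumulated once per occurrence of symbol i
        prev = grads.get(i, [0.0] * len(target))
        grads[i] = [p + r for p, r in zip(prev, residual)]
    return grads


kv = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
print(forward_backward(kv, ids=[0, 2], target=[0.0, 0.0]))  # {0: [2.0, 1.0], 2: [2.0, 1.0]}
```

Row 1 never appears in the result: W'(x) is as sparse as x itself.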

During the global update process:

1. Each trainer uploads its W'(x) to parameter servers.
1. The parameter server runs the optimization algorithm, e.g., the
   Adam optimization algorithm, which requires that
   1. the parameter server retrieves W(x) from memcached, and
   1. the parameter server pushes $\Delta W(x) = f(W(x), \lambda \sum_j W'_j(x))$
      to memcached, where $f$ denotes the optimization algorithm and
      $W'_j(x)$ is the gradient uploaded by trainer $j$.
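
The global update can be sketched the same way, assuming a dict stands in for memcached and using plain SGD as a stand-in for $f$ (the text mentions Adam; SGD keeps the sketch short). Because a plain key-value store cannot add a delta to a stored row by itself, this sketch writes back $W(x) + \Delta W(x)$; that detail, like the names `aggregate` and `pserver_update`, is an assumption.

```python
from typing import Dict, List

Row = List[float]


def aggregate(trainer_grads: List[Dict[int, Row]]) -> Dict[int, Row]:
    """sum_j W'_j(x): add up the sparse row gradients uploaded by each trainer j."""
    total: Dict[int, Row] = {}
    for grads in trainer_grads:
        for row_id, g in grads.items():
            prev = total.get(row_id, [0.0] * len(g))
            total[row_id] = [a + b for a, b in zip(prev, g)]
    return total


def pserver_update(kv: Dict[int, Row], trainer_grads: List[Dict[int, Row]], lr: float) -> None:
    """Apply f (plain SGD here) to every row touched by any trainer."""
    for row_id, g in aggregate(trainer_grads).items():
        w = kv[row_id]                                      # 1. retrieve W(x) from the store
        delta = [-lr * gi for gi in g]                      # Delta W(x) = -lambda * sum_j W'_j(x)
        kv[row_id] = [wi + di for wi, di in zip(w, delta)]  # 2. write the updated row back


store = {0: [1.0, 1.0], 1: [1.0, 1.0]}
uploads = [{0: [1.0, 0.0]}, {0: [1.0, 2.0], 1: [0.5, 0.5]}]  # from two trainers
pserver_update(store, uploads, lr=0.5)
print(store)  # {0: [0.0, 0.0], 1: [0.75, 0.75]}
```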

### Storage Service Does Optimize

This design is very similar to the above one, except that the
optimization algorithm $f$ runs on the storage service.

- Pro: parameter servers do not retrieve W(x) from the storage
  service, thus saving half of the network communication.
- Con: the storage service needs to be able to run the optimization
  algorithm.
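
A toy comparison of the two designs, counting the rows that travel between a parameter server and the storage service during one update. `PlainKV`, `OptimizingKV`, `pserver_side_update`, and the row counter are hypothetical, and plain SGD again stands in for $f$. Both designs produce the same table, but the optimizing store moves half as many rows.

```python
from typing import Dict, List

Row = List[float]


class PlainKV:
    """Storage that can only get and set rows (the memcached-style design)."""

    def __init__(self, table: Dict[int, Row]) -> None:
        self.table = table
        self.rows_transferred = 0  # rows sent over the pserver <-> storage link

    def get(self, row_id: int) -> Row:
        self.rows_transferred += 1
        return list(self.table[row_id])

    def set(self, row_id: int, row: Row) -> None:
        self.rows_transferred += 1
        self.table[row_id] = row


class OptimizingKV(PlainKV):
    """Hypothetical storage that runs f (plain SGD here) next to the data."""

    def __init__(self, table: Dict[int, Row], lr: float) -> None:
        super().__init__(table)
        self.lr = lr

    def apply_gradient(self, row_id: int, grad: Row) -> None:
        self.rows_transferred += 1  # only the gradient row crosses the network
        w = self.table[row_id]
        self.table[row_id] = [wi - self.lr * gi for wi, gi in zip(w, grad)]


def pserver_side_update(kv: PlainKV, grads: Dict[int, Row], lr: float) -> None:
    """Storage doesn't optimize: the pserver reads W(x), applies f, writes it back."""
    for row_id, g in grads.items():
        w = kv.get(row_id)
        kv.set(row_id, [wi - lr * gi for wi, gi in zip(w, g)])


grads = {0: [1.0], 1: [2.0]}
plain = PlainKV({0: [1.0], 1: [1.0]})
pserver_side_update(plain, grads, lr=0.5)
smart = OptimizingKV({0: [1.0], 1: [1.0]}, lr=0.5)
for row_id, g in grads.items():
    smart.apply_gradient(row_id, g)
print(plain.table == smart.table, plain.rows_transferred, smart.rows_transferred)  # True 4 2
```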

## Distributed Sparse Table in Fluid

As an alternative design, we can implement a distributed sparse table in
Fluid itself, so that no external storage component needs to be
maintained during training.

Before reading this design, readers may find it useful to familiarize
themselves with Fluid [Distributed Training Architecture](./distributed_architecture.md)
and [Parameter Server](./parameter_server.md).

![fluid lookup remote table](./src/fluid_lookup_remote_table.png)

To partition a large table across multiple pserver instances:
1. `DistributeTranspiler` splits the table into smaller table blocks
   using a partitioning algorithm such as
   [RoundRobin](https://en.wikipedia.org/wiki/Round-robin_scheduling),
   [Hash](https://en.wikipedia.org/wiki/Hash), etc.
1. In some cases, the range of input `Ids` is very wide and unpredictable,
   so the sparse table should be able to initialize the value of an id
   that has not appeared before, either with zeros or with values drawn
   from a uniform or Gaussian distribution.
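
A minimal sketch of both ideas: two ways to map an id to a pserver (id modulo the pserver count as a round-robin-style assignment, and a stable crc32 hash), plus a table block that creates rows for unseen ids on first lookup. The function and class names, the `init` strings, and the uniform range $[-1, 1]$ are assumptions for illustration, not `DistributeTranspiler` behavior.

```python
import random
import zlib
from typing import Dict, List

Row = List[float]


def shard_by_mod(row_id: int, num_pservers: int) -> int:
    """Round-robin style: consecutive ids land on consecutive pservers."""
    return row_id % num_pservers


def shard_by_hash(row_id: int, num_pservers: int) -> int:
    """Hash style: a stable crc32 of the id's bytes decides the owner."""
    return zlib.crc32(row_id.to_bytes(8, "little", signed=True)) % num_pservers


class SparseTableBlock:
    """One pserver's block of the table; rows for unseen ids are created on first lookup."""

    def __init__(self, width: int, init: str = "zero", seed: int = 0) -> None:
        if init not in ("zero", "uniform", "gaussian"):
            raise ValueError(f"unknown init: {init}")
        self.width = width
        self.init = init
        self.rows: Dict[int, Row] = {}
        self.rng = random.Random(seed)

    def _new_row(self) -> Row:
        if self.init == "zero":
            return [0.0] * self.width
        if self.init == "uniform":
            return [self.rng.uniform(-1.0, 1.0) for _ in range(self.width)]
        return [self.rng.gauss(0.0, 1.0) for _ in range(self.width)]

    def lookup(self, row_id: int) -> Row:
        if row_id not in self.rows:
            self.rows[row_id] = self._new_row()
        return self.rows[row_id]


block = SparseTableBlock(width=4, init="gaussian")
owner = shard_by_hash(10**12, 3)  # a very large, previously unseen id
row = block.lookup(10**12)        # created on demand, then reused on later lookups
```

Because rows are materialized lazily, memory grows with the ids actually seen rather than with the full id range.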

In each trainer's training process:
1. In the forward pass, instead of the local `lookup_table` op, a `pre-fetch`
   op pre-fetches the parameter blocks for the input `Ids` from the PServers
   and then merges the blocks into a parameter `W`.
1. In the backward pass, the trainer computes `GRAD@W'` using the pre-fetched
   `W` and sends it to the PServers, which execute the optimize pass.
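
The pre-fetch step can be sketched as: group the unique `Ids` by owning pserver, fetch each group, then merge the rows back into W(x) in the original `Ids` order. Here the pservers are plain dicts, ownership is assumed to be `id % N`, and `split_ids`/`prefetch` are hypothetical names rather than the actual Fluid ops.

```python
from typing import Dict, List, Sequence

Row = List[float]


def split_ids(ids: Sequence[int], num_pservers: int) -> Dict[int, List[int]]:
    """Group the unique Ids by owning pserver, assuming id % num_pservers ownership."""
    groups: Dict[int, List[int]] = {}
    for i in dict.fromkeys(ids):  # de-duplicate while keeping first-seen order
        groups.setdefault(i % num_pservers, []).append(i)
    return groups


def prefetch(ids: Sequence[int], pservers: List[Dict[int, Row]]) -> List[Row]:
    """Fetch the rows for ids from their pservers and merge them into W(x)."""
    fetched: Dict[int, Row] = {}
    for shard, shard_ids in split_ids(ids, len(pservers)).items():
        block = pservers[shard]  # in a real system: one request per pserver
        for i in shard_ids:
            fetched[i] = list(block[i])
    # merge: one row per input id, in input order, duplicates included
    return [list(fetched[i]) for i in ids]


pservers = [{0: [0.0], 2: [2.0]}, {1: [1.0], 3: [3.0]}]
print(prefetch([3, 0, 3, 2], pservers))  # [[3.0], [0.0], [3.0], [2.0]]
```

De-duplicating before the fetch means a frequent id costs one transfer per batch, no matter how often it repeats.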

## Conclusion

Let us implement the "storage service does not optimize" solution first,
at least as a baseline, because it is easier to use a well-optimized
distributed storage service like memcached.  We can implement the
"storage service does optimize" solution later or in parallel; if
implemented carefully, it should perform better than the former.