# Design Doc: Supporting new Device/Library

## Background

Deep learning has a high demand for computing resources. New high-performance devices and computing libraries appear frequently, and deep learning frameworks have to integrate them flexibly and efficiently.

On one hand, hardware and computing libraries usually do not have a one-to-one correspondence. For example, Intel CPUs support the Eigen and MKL computing libraries, while Nvidia GPUs support the Eigen and cuDNN computing libraries. We have to implement operator-specific kernels for each computing library.

On the other hand, users usually do not want to care about low-level hardware and computing libraries when writing a neural network configuration. In Fluid, `Layer` is exposed in `Python`, and `Operator` is exposed in `C++`. Both `Layer` and `Operator` are hardware independent.

So, how to support a new Device/Library in Fluid becomes a challenge.


## Basic: Integrate A New Device/Library

For a general overview of Fluid, please refer to the [overview doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/read_source.md).

There are mainly three parts that we have to consider while integrating a new device/library:

- Place and DeviceContext: indicate the device id and manage hardware resources

- Memory and Tensor: malloc/free data on a certain device

- Math Functor and OpKernel: implement computing units on certain devices/libraries

### Place and DeviceContext

Please remember that devices and computing libraries do not correspond one-to-one. A device can have several computing libraries, and a computing library can also support several devices.

#### Place
Fluid uses class [Place](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/place.h#L55) to represent the device memory where data is located. If we add another device, we have to add the corresponding `DevicePlace`.

```
        |   CPUPlace
Place --|   CUDAPlace
        |   FPGAPlace
```

And `Place` is defined as follows:

```
typedef boost::variant<CUDAPlace, CPUPlace, FPGAPlace> Place;
```
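
As a rough illustration of how a new place type joins the variant, here is a minimal sketch. It assumes C++17 and uses `std::variant` in place of `boost::variant` so that it has no external dependency. The struct fields and the helpers `PlaceName`, `DescribePlace`, and `IsFPGAPlace` are hypothetical and are not part of Fluid.

```cpp
#include <string>
#include <variant>

// Illustrative stand-ins for the real place types; the fields are assumptions.
struct CPUPlace {};
struct CUDAPlace { int device = 0; };
struct FPGAPlace { int device = 0; };  // hypothetical new device

// std::variant is used here instead of boost::variant to stay dependency-free.
using Place = std::variant<CUDAPlace, CPUPlace, FPGAPlace>;

// Visitor that names the device kind held by a Place.
struct PlaceName {
  std::string operator()(const CPUPlace&) const { return "CPU"; }
  std::string operator()(const CUDAPlace& p) const {
    return "CUDA:" + std::to_string(p.device);
  }
  std::string operator()(const FPGAPlace& p) const {
    return "FPGA:" + std::to_string(p.device);
  }
};

// Dispatches on the alternative currently stored in the variant.
inline std::string DescribePlace(const Place& place) {
  return std::visit(PlaceName{}, place);
}

// True when the Place currently holds the new device type.
inline bool IsFPGAPlace(const Place& place) {
  return std::holds_alternative<FPGAPlace>(place);
}
```

Because the visitor must handle every alternative, adding a place type to the variant without extending `PlaceName` fails at compile time, which helps catch code paths that were not updated for the new device.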

#### DeviceContext

Fluid uses class [DeviceContext](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/device_context.h#L30) to manage the resources in different libraries, such as the CUDA stream in `CUDADeviceContext`. There are also inheritance relationships between different kinds of `DeviceContext`.


```
                /->  CPUDeviceContext   --> MKLDeviceContext
DeviceContext ---->  CUDADeviceContext  --> CUDNNDeviceContext
                \->  FPGADeviceContext
```

An example for Nvidia GPUs is as follows:

- DeviceContext


```
class DeviceContext {
 public:
  virtual ~DeviceContext() {}
  virtual Place GetPlace() const = 0;
};
```


- CUDADeviceContext


```
class CUDADeviceContext : public DeviceContext {
 public:
  Place GetPlace() const override { return place_; }

 private:
  CUDAPlace place_;
  cudaStream_t stream_;
  cublasHandle_t cublas_handle_;
  std::unique_ptr<Eigen::GpuDevice> eigen_device_;  // binds with stream_
};
```

- CUDNNDeviceContext

```
class CUDNNDeviceContext : public CUDADeviceContext {
 private:
  cudnnHandle_t cudnn_handle_;
};
```
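
For a new device, the same pattern could be followed by deriving another context from `DeviceContext`. The sketch below is only an assumption of what that might look like: `FPGADeviceContext`, `FPGAPlace`, and the `queue_id_` member are hypothetical, and `std::variant` (C++17) stands in for `boost::variant`.

```cpp
#include <variant>

struct CPUPlace {};
struct CUDAPlace { int device = 0; };
struct FPGAPlace { int device = 0; };  // hypothetical new device
using Place = std::variant<CUDAPlace, CPUPlace, FPGAPlace>;

// Abstract base: every context reports the place it is bound to.
class DeviceContext {
 public:
  virtual ~DeviceContext() = default;
  virtual Place GetPlace() const = 0;
};

// Hypothetical context for an FPGA. A real one would own vendor handles
// (command queues, library handles) the way CUDADeviceContext owns a stream.
class FPGADeviceContext : public DeviceContext {
 public:
  explicit FPGADeviceContext(FPGAPlace place) : place_(place) {}
  Place GetPlace() const override { return place_; }
  int queue_id() const { return queue_id_; }

 private:
  FPGAPlace place_;
  int queue_id_ = 0;  // placeholder for a vendor-specific resource handle
};
```

Code that only needs the place can keep working through the `DeviceContext` base, while device-specific kernels take the concrete context type to reach its handles.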


### Memory and Tensor


#### memory module

Fluid provides the following [memory interfaces](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/memory/memory.h#L36):

```
template <typename Place>
void* Alloc(Place place, size_t size);

template <typename Place>
void Free(Place place, void* ptr);

template <typename Place>
size_t Used(Place place);
```

To implement these interfaces, we have to implement a `MemoryAllocator` for each device.
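
To make the relationship between the interface and a per-device allocator concrete, here is a minimal sketch for a CPU-like place. `CPUAllocator` and `GetCPUAllocator` are hypothetical names, the allocator is plain `malloc`/`free` with byte accounting, and it is not thread-safe. Fluid's real allocators are more elaborate.

```cpp
#include <cstddef>
#include <cstdlib>
#include <unordered_map>

struct CPUPlace {};

// Primary templates, mirroring the interface shown above.
template <typename Place> void* Alloc(Place place, std::size_t size);
template <typename Place> void Free(Place place, void* ptr);
template <typename Place> std::size_t Used(Place place);

// Hypothetical allocator: malloc/free plus a record of live block sizes.
class CPUAllocator {
 public:
  void* Alloc(std::size_t size) {
    void* ptr = std::malloc(size);
    if (ptr != nullptr) {
      sizes_[ptr] = size;
      used_ += size;
    }
    return ptr;
  }
  void Free(void* ptr) {
    auto it = sizes_.find(ptr);
    if (it == sizes_.end()) return;  // not allocated here; ignore
    used_ -= it->second;
    sizes_.erase(it);
    std::free(ptr);
  }
  std::size_t Used() const { return used_; }

 private:
  std::unordered_map<void*, std::size_t> sizes_;
  std::size_t used_ = 0;
};

// One allocator instance per device kind, created on first use.
inline CPUAllocator& GetCPUAllocator() {
  static CPUAllocator allocator;
  return allocator;
}

// Specializations that route the generic interface to the CPU allocator.
template <> inline void* Alloc<CPUPlace>(CPUPlace, std::size_t size) {
  return GetCPUAllocator().Alloc(size);
}
template <> inline void Free<CPUPlace>(CPUPlace, void* ptr) {
  GetCPUAllocator().Free(ptr);
}
template <> inline std::size_t Used<CPUPlace>(CPUPlace) {
  return GetCPUAllocator().Used();
}
```

Supporting a new device would then mean adding another allocator and another set of specializations keyed on the new place type.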


#### Tensor

[Tensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/tensor.h#L36) holds data with some shape in a specific Place.

```cpp
class Tensor {
 public:
  /*! Return a pointer to mutable memory block. */
  template <typename T>
  inline T* data();

  /**
   * @brief   Return a pointer to mutable memory block.
   * @note    Allocates the block if it does not exist yet.
   */
  template <typename T>
  inline T* mutable_data(platform::Place place);

  /**
   * @brief     Return a pointer to mutable memory block.
   *
   * @param[in] dims    The dimensions of the memory block.
   * @param[in] place   The place of the memory block.
   *
   * @note      Allocates the block if it does not exist yet.
   */
  template <typename T>
  inline T* mutable_data(DDim dims, platform::Place place);

  /*! Resize the dimensions of the memory block. */
  inline Tensor& Resize(const DDim& dims);

  /*! Return the dimensions of the memory block. */
  inline const DDim& dims() const;

 private:
  /*! holds the memory block if allocated. */
  std::shared_ptr<Placeholder> holder_;

  /*! points to dimensions of memory block. */
  DDim dim_;
};
```

`Placeholder` is used to delay memory allocation; that is, we can first define a tensor, use `Resize` to configure its shape, and then call `mutable_data` to allocate the actual memory.

```cpp
paddle::framework::Tensor t;
paddle::platform::CPUPlace place;
// set size first
t.Resize({2, 3});
// allocate memory on CPU later
t.mutable_data<float>(place);
```
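
The delayed allocation described above can be sketched with a much smaller, CPU-only class. `LazyTensor`, its `allocated()` and `numel()` helpers, and the `std::vector`-based `DDim` are assumptions made for illustration and differ from the real `Tensor`.

```cpp
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

// Simplified stand-in for DDim.
using DDim = std::vector<std::int64_t>;

class LazyTensor {  // hypothetical, CPU-only
 public:
  // Records the shape only; no memory is touched here.
  LazyTensor& Resize(const DDim& dims) {
    dims_ = dims;
    return *this;
  }

  // Number of elements implied by the current shape (0 if no shape is set).
  std::int64_t numel() const {
    if (dims_.empty()) return 0;
    std::int64_t n = 1;
    for (std::int64_t d : dims_) n *= d;
    return n;
  }

  bool allocated() const { return holder_ != nullptr; }

  // Allocates on first call, or again when the current block is too small.
  template <typename T>
  T* mutable_data() {
    std::size_t bytes = static_cast<std::size_t>(numel()) * sizeof(T);
    if (holder_ == nullptr || holder_->size < bytes) {
      holder_ = std::make_shared<Placeholder>(bytes);
    }
    return reinterpret_cast<T*>(holder_->buffer.get());
  }

 private:
  // Owns the raw memory block once it exists.
  struct Placeholder {
    explicit Placeholder(std::size_t n)
        : buffer(new unsigned char[n]), size(n) {}
    std::unique_ptr<unsigned char[]> buffer;
    std::size_t size;
  };

  std::shared_ptr<Placeholder> holder_;
  DDim dims_;
};
```

Calling `Resize` leaves `allocated()` false, and the first `mutable_data<float>()` call creates the block. Repeated calls with the same shape return the same pointer.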



### Math Functor and OpKernel

Fluid implements computing units based on different DeviceContexts. Some computing units are shared between operators. These common parts are placed in the `operators/math` directory as basic functors.

Let's take [MaxOutFunctor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/math/maxouting.h#L27) as an example:

The interface is defined in the header file:

```
template <typename DeviceContext, typename T>
class MaxOutFunctor {
 public:
  void operator()(const DeviceContext& context, const framework::Tensor& input,
                  framework::Tensor* output, int groups);
};
```

The CPU implementation is in the `.cc` file:

```
template <typename T>
class MaxOutFunctor<platform::CPUDeviceContext, T> {
 public:
  void operator()(const platform::CPUDeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
                  int groups) {
    ...
  }
};
```

The CUDA implementation is in the `.cu` file:

```
template <typename T>
class MaxOutFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
                  int groups) {
    ...
  }
};
```
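
The functor pattern above relies on partial specialization by context type. A self-contained sketch of the same idea follows. `ScaleFunctor`, `Scale`, and the empty context structs are hypothetical, and `std::vector` stands in for `framework::Tensor`.

```cpp
#include <cstddef>
#include <vector>

struct CPUDeviceContext {};
struct FPGADeviceContext {};  // hypothetical new device, no functor yet

// Interface: declared once, specialized per DeviceContext.
template <typename DeviceContext, typename T>
class ScaleFunctor;  // hypothetical functor, not part of Fluid

// CPU specialization: multiplies every element by a factor.
template <typename T>
class ScaleFunctor<CPUDeviceContext, T> {
 public:
  void operator()(const CPUDeviceContext&, const std::vector<T>& input,
                  std::vector<T>* output, T factor) const {
    output->resize(input.size());
    for (std::size_t i = 0; i < input.size(); ++i) {
      (*output)[i] = input[i] * factor;
    }
  }
};

// Generic caller: the context type selects the specialization.
template <typename DeviceContext, typename T>
std::vector<T> Scale(const DeviceContext& ctx, const std::vector<T>& input,
                     T factor) {
  std::vector<T> output;
  ScaleFunctor<DeviceContext, T>()(ctx, input, &output, factor);
  return output;
}
```

In this sketch, calling `Scale` with `FPGADeviceContext` would fail to compile until a matching specialization is written, so a missing device implementation is reported at build time.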


We get the computing handle from a concrete DeviceContext and perform the computation on tensors.

The implementation of `OpKernel` is similar to that of math functors; the extra thing we need to do is to register the OpKernel in a global map.

Fluid provides different registration interfaces in `op_registry.h`.


Let's take [Crop](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/crop_op.cc#L134) operator as an example:

In .cc file:

```
REGISTER_OP_CPU_KERNEL(crop, ops::CropKernel<float>);
REGISTER_OP_CPU_KERNEL(
    crop_grad, ops::CropGradKernel<paddle::platform::CPUDeviceContext, float>);
```

In .cu file:

```
REGISTER_OP_CUDA_KERNEL(crop, ops::CropKernel<float>);
REGISTER_OP_CUDA_KERNEL(
    crop_grad, ops::CropGradKernel<paddle::platform::CUDADeviceContext, float>);
```
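
The idea of registering kernels in a global map can be sketched as follows. This is not how `op_registry.h` is implemented: the `Device` enum, `KernelRegistry`, `KernelRegistrar`, and the `SKETCH_REGISTER_CPU_KERNEL` macro are hypothetical, and a kernel is reduced to a plain `int(int)` function.

```cpp
#include <functional>
#include <map>
#include <string>
#include <utility>

enum class Device { kCPU, kCUDA };  // simplified stand-in for a kernel key
using KernelFn = std::function<int(int)>;
using KernelKey = std::pair<std::string, Device>;

// Global map, built on first use to avoid static initialization order issues.
inline std::map<KernelKey, KernelFn>& KernelRegistry() {
  static std::map<KernelKey, KernelFn> registry;
  return registry;
}

// Registers a kernel from the constructor of a static object.
struct KernelRegistrar {
  KernelRegistrar(const std::string& op, Device device, KernelFn fn) {
    KernelRegistry()[{op, device}] = std::move(fn);
  }
};

// Hypothetical macro, loosely modeled on REGISTER_OP_CPU_KERNEL.
#define SKETCH_REGISTER_CPU_KERNEL(op, fn) \
  static KernelRegistrar registrar_cpu_##op(#op, Device::kCPU, fn)

// A toy "kernel" and its registration at namespace scope.
inline int DoubleIt(int x) { return 2 * x; }
SKETCH_REGISTER_CPU_KERNEL(double_it, DoubleIt);
```

Since the registrar is a static object, the kernel is in the map before `main` starts, and a lookup by operator name and device then finds it without the caller naming the kernel class.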


## Advanced topics: How to switch between different Device/Library

Generally, we implement an OpKernel for every Device/Library that an Operator supports, so we can easily train a Convolutional Neural Network on a GPU. However, some OpKernels are not suitable for a specific Device. For example, the crf operator can only run on CPU, whereas most other operators can run on GPU. To achieve high performance in such circumstances, we have to switch between different Devices/Libraries.
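
One way to picture the switching problem is a fallback rule: run each operator on the preferred device when a kernel exists there, otherwise fall back to CPU and copy data across the boundary. The sketch below only illustrates that rule. `KernelTable`, `ChooseDevice`, and `NeedsTransfer` are hypothetical, and the actual mechanism is the one described in the linked design docs.

```cpp
#include <map>
#include <set>
#include <stdexcept>
#include <string>

enum class Device { kCPU, kCUDA };

// Which devices each operator has a kernel for (illustrative data only).
using KernelTable = std::map<std::string, std::set<Device>>;

// Returns the preferred device when a kernel exists there, otherwise CPU.
inline Device ChooseDevice(const KernelTable& table, const std::string& op,
                           Device preferred) {
  auto it = table.find(op);
  if (it == table.end() || it->second.empty()) {
    throw std::runtime_error("no kernel registered for op: " + op);
  }
  if (it->second.count(preferred) > 0) return preferred;
  if (it->second.count(Device::kCPU) > 0) return Device::kCPU;
  throw std::runtime_error("no usable kernel for op: " + op);
}

// Data must be copied when neighbouring operators run on different devices.
inline bool NeedsTransfer(Device producer, Device consumer) {
  return producer != consumer;
}
```

With a table in which `crf` has only a CPU kernel, a network that prefers CUDA would run its convolution on the GPU, run `crf` on the CPU, and need one transfer between them.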


For more details, please refer to the following docs:

- operator kernel type [doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md)
- switch kernel [doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md)