First of all, we should follow some basic principles:
1. [How to write a new operator](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_en.md). We are trying to add a new kind of kernel into operators, so basically we should follow this doc.
2. [Supporting new Device/Library](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/support_new_device.md). Since MKLDNN is a new library to fluid, we should add `MKLDNNDeviceContext` and maybe `mkldnn_helper.h`, just like [cudnn_helper.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/cudnn_helper.h).
3. [Switch Kernel](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md). Another important point is that we should ensure data synchronization between different kernel types, which is discussed in this [topic](https://github.com/PaddlePaddle/Paddle/issues/6549). So basically we should override the `GetExpectedKernelType` and `trans` functions to support switching kernels (a small sketch follows this list).
4. [The Keys of Operator Kernel Type](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md). Kernel type is a pivotal concept that records the `Place`, `Library`, `DataType` and `Layout` of a kernel.
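To make principles 3 and 4 concrete, below is a minimal sketch of how an operator could ask for an MKLDNN kernel from `GetExpectedKernelType`. It is only an illustration: the enum values `LibraryType::kMKLDNN` and `DataLayout::kMKLDNN` and the availability check are assumptions, not something that already exists in fluid.

```cpp
// Sketch only: assumes LibraryType::kMKLDNN and DataLayout::kMKLDNN will be
// added, and that the operator can decide at runtime whether MKL-DNN is usable.
framework::OpKernelType ConvOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  framework::LibraryType library = framework::LibraryType::kPlain;
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
  if (platform::is_cpu_place(ctx.GetPlace()) /* && MKL-DNN is available */) {
    library = framework::LibraryType::kMKLDNN;  // request the MKLDNN kernel
    layout = framework::DataLayout::kMKLDNN;    // and its preferred layout
  }
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
      ctx.GetPlace(), layout, library);
}
```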
## Solution
In general, there are four parts we should follow to run an MKL-DNN primitive (sketched in code below):
- Create a primitive descriptor that describes this operator.
- Create the primitive itself from the primitive descriptor and the engine.
- Create all the memory buffers that the primitive needs.
- Launch a stream to execute the created primitive.
More details can be found [here](http://01org.github.io/mkl-dnn).
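As a rough standalone reference for these four steps, the following sketch uses the MKL-DNN 0.x C++ API to run a ReLU forward primitive on toy data; exact class and enum names depend on the MKL-DNN version, so treat it as an illustration rather than the code we will ship.

```cpp
#include <vector>
#include "mkldnn.hpp"

int main() {
  using namespace mkldnn;
  engine cpu_engine(engine::cpu, 0);

  // Toy shapes and buffers for this illustration.
  const int n = 1, c = 3, h = 4, w = 4;
  std::vector<float> src_data(n * c * h * w, -1.0f), dst_data(n * c * h * w);

  // 1. Create a primitive descriptor that describes this operator.
  memory::desc src_md({n, c, h, w}, memory::data_type::f32,
                      memory::format::nchw);
  auto relu_desc = eltwise_forward::desc(prop_kind::forward_inference,
                                         algorithm::eltwise_relu, src_md,
                                         /*alpha=*/0.0f, /*beta=*/0.0f);
  auto relu_pd = eltwise_forward::primitive_desc(relu_desc, cpu_engine);

  // 2./3. Create the memory buffers and the primitive itself
  // (the API needs the memories before the primitive can be built).
  memory src_mem({src_md, cpu_engine}, src_data.data());
  memory dst_mem(relu_pd.dst_primitive_desc(), dst_data.data());
  auto relu = eltwise_forward(relu_pd, src_mem, dst_mem);

  // 4. Launch a stream to execute the created primitive.
  stream(stream::kind::eager).submit({relu}).wait();
  return 0;
}
```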
It's better to avoid re-initializing the primitives and memory handles created in the first three stages at every iteration. So we plan to create a map to record all the `primitive` and `memory` objects, which should not take too much memory, as discussed [here](https://github.com/PaddlePaddle/Paddle/issues/6822).
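One possible shape of such a cache, assuming the key is a string built per operator instance, is a simple map from key to a type-erased blob. The class and method names below (`DeviceContextBlobMap`, `SetBlob`, `GetBlob`) are hypothetical, not an existing fluid API.

```cpp
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical per-device-context cache for MKL-DNN objects. Values are
// type-erased so one map can hold primitive descriptors, primitives and
// memories alike.
class DeviceContextBlobMap {
 public:
  void SetBlob(const std::string& key, std::shared_ptr<void> blob) {
    blobs_[key] = std::move(blob);
  }
  std::shared_ptr<void> GetBlob(const std::string& key) const {
    auto it = blobs_.find(key);
    return it == blobs_.end() ? nullptr : it->second;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
};
```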
It's assumed that the following three conditions are satisfied:
1. There is a unique key for each operator instance, which may be the actual name of its `Output Tensor` (a hypothetical key helper is sketched after this list).
2. The `Input Tensor` inside the `Compute` function is the one after conversion (i.e., already transformed into the layout the kernel expects).
3. We can get the phase (e.g. `is_test`) inside the `Compute` function; otherwise we need to expose this attribute to the user.
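For condition 1, a hypothetical helper that derives the unique key from the output tensor name, combined with the input dimensions so that shape changes get a fresh cache entry, could look like this; `GetMKLDNNKey` is an illustrative name, not an agreed-on interface.

```cpp
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper: build a unique per-instance key from the output
// variable name plus the input dimensions.
std::string GetMKLDNNKey(const std::string& output_name,
                         const std::vector<int64_t>& input_dims) {
  std::ostringstream key;
  key << output_name;
  for (int64_t d : input_dims) key << "_" << d;
  return key.str();
}
```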
The algorithm of `Compute` would be described as follows; let's take conv as an example:
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), "It must use CPUPlace.");
PADDLE_ENFORCE(platform::is_mkldnn_library(ctx.GetLibrary()), "It must use MKLDNN Library.");
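// What follows is a hedged sketch (not an agreed-on API) of how the cached
// objects could be looked up and reused; `GetMKLDNNKey`, `GetBlob` and
// `SetBlob` are hypothetical names, and `MKLDNNDeviceContext` is the context
// proposed earlier in this doc.
auto& dev_ctx = ctx.template device_context<platform::MKLDNNDeviceContext>();
const std::string key = GetMKLDNNKey(
    ctx.op().Output("Output"),
    framework::vectorize(ctx.Input<Tensor>("Input")->dims()));
auto conv_pd = dev_ctx.GetBlob(key + "@conv_pd");
if (conv_pd == nullptr) {
  // First iteration: create the primitive descriptor, memories and the
  // primitive, then record them so later iterations skip re-initialization.
  // ... MKL-DNN object creation elided ...
  dev_ctx.SetBlob(key + "@conv_pd", conv_pd);
}
// Finally, submit the (possibly cached) primitive to an MKL-DNN stream.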