diff --git a/doc/design/mkl/mkldnn_fluid.md b/doc/design/mkl/mkldnn_fluid.md new file mode 100644 index 0000000000000000000000000000000000000000..bef126f3f0577b69f646dfe5d10539b372c6a8a5 --- /dev/null +++ b/doc/design/mkl/mkldnn_fluid.md @@ -0,0 +1,149 @@ +# Design Doc: Add MKLDNN Kernel in Fluid Operator + +## Principles + +First of all, we should follow some basical principles like: +1. [How to write a new operator](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_en.md). We are trying to add a new kind of kernel into operators, so basically we should follow this doc. +2. [Supporting new Device/Library](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/support_new_device.md). Since MKLDNN is a new library to fluid, we should add `MKLDNNDeviceContext` and maybe `mkldnn_helper.h`, just like [cudnn_helper.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/cudnn_helper.h). +3. [Switch Kernel](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md). Another important point is that we should ensure the data synchronization between different kernel types, which is this [topic](https://github.com/PaddlePaddle/Paddle/issues/6549). So basically we should override `GetExpectedKernelType` and `trans` functions to support switching kernels. +4. [The Keys of Operator Kernel Type](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md). Kernel Type is a pivotal conception which can record the `Place`, `Library`, `DataType` and `Layout`. + +## Sulution + +In general, there are four parts we should follow to run a MKL-DNN primitive. +- Create a primitive descriptor that describe this operator +- Create a primitive itself by primitive descriptor and the engine +- Create all memory buffers that primitive needed +- Launch a stream to execute the primitive created +More details can refer to [here](http://01org.github.io/mkl-dnn). + +It's better to avoid reinitialization of primitives and memory handles in the first three stages in every iteration. \ +So we plan to create a map to record all the `primitive` and `memory`, which should not take too much memories as discussed [here](https://github.com/PaddlePaddle/Paddle/issues/6822). + +It's assumed that following three conditions should be satisfied. +1. there is a unique key for each operator instance. May be the actual name of `Output Tensor`. +2. the `Input Tensor` inside `Compute` function is the one after converted. +3. we can get the phase(eg. `is_test`) inside `Compute` function, otherwise we need to expose this attribue to user. + +### Compute +The algorithm of `Compute` would be described as follow, let's take conv like an example. + +```c++ + + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), "It must use CPUPlace."); + PADDLE_ENFORCE(platform::is_mkldnn_library(ctx.GetLibrary()), "It must use MKLDNN Library."); + + auto& dev_ctx = ctx.template device_context(); + + // find primitive by unique key from mkldnn context + // the op_key should be a unique name of this op instance + auto& p = dev_ctx.findPrimitive(op_key + "_fwd"); + + // assuming the input tensor inside this compute function is the one after converted + // this point should be guarantee by another mechanism + auto& i = dev_ctx.findMemory(op_key + "_input"); + + if (p == nullptr || i == nullptr || inputSizeChanged(p, i)) { + auto fwd_primitive_desc = createPrimitiveDesc(ctx); + auto* input = ctx.Input("Input"); + auto* filter = ctx.Input("Filter"); + auto* output = ctx.Output("Output"); + shared_ptr in(new mkldnn::memory(fwd_primitive_desc->src_primitive_desc(), input->data())); + shared_ptr wgt(new mkldnn::memory(fwd_primitive_desc->weights_primitive_desc(), filter->data())); + shared_ptr out(new mkldnn::memory(fwd_primitive_desc->dst_primitive_desc(), output->mutable_data(ctx.GetPlace()))); + shared_ptr fwd_primitive(new mkldnn::conv_fwd(*fwd_primitive_desc, *in, *wgt, *out)); + + dev_ctx.addMemory(op_key+"_input", in); + dev_ctx.addMemory(op_key+"_output", out); + dev_ctx.addMemory(op_key+"_filer", wgt); + dev_ctx.addPrimitive(op_key+"_fwd", fwd_primitive); + dev_ctx.addPrimitiveDesc(op_key+"_fwd_PD", fwd_primitive_desc); + } + + p = dev_ctx.findPrimitive(op_key + "_fwd"); + + PADDLE_ENFORCE(p, "Should have forward Primitive"); + PADDLE_ENFORCE(dev_ctx.findMemory(op_unique_key+"_input"), "Should have input memory"); + PADDLE_ENFORCE(dev_ctx.findMemory(op_unique_key+"_output"), "Should have output memory"); + PADDLE_ENFORCE(dev_ctx.findMemory(op_unique_key+"_filter"), "Should have filter memory"); + PADDLE_ENFORCE(dev_ctx.findPrimitiveDesc(op_unique_key+"_fwd_PD"), "Should have forward PrimitiveDesc"); + dev_ctx.submit(p); + dev_ctx.execute(); // the convert primitive should have already contained. + +``` + +The `createPrimitiveDesc` returns the primitive descripotor of this operator, would be like this: +```c++ + auto* input = ctx.Input("Input"); + auto* filter = ctx.Input("Filter"); + auto* output = ctx.Output("Output"); + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::vector dilations = ctx.Attr>("dilations"); + int groups = ctx.Attr("groups"); + algorithm algo = static_cast(ctx.Attr("convolution_algorithm_option")); + prop_kind pk = ctx.Attr("is_test") ? prop_kind::forward_inference : prop_kind::forward_training; + + auto fwd_desc = mkldnn::conv_fwd::desc(/* all the setting above*/); + shared_ptr fwd_primitive_desc(new mkldnn::conv_fwd::primitive_desc(fwd_desc, ctx.getEngine())); + + return fwd_primitive_desc; + } +``` + +### MKLDNNDeviceContext +`MKLDNNDeviceContext`, which is very straightforward, should contain some base information like: `stream`, `engine` and the map needed. + + +### mkldnn_helper +Some functions would be put in `paddle/platform/mkldnn_helper.h`. +- create MKLDNN memories +- create MKLDNN primitives +- error check function +- etc + + +### Kernel Switch +We should `reorder` the different Layout from other device or to other device. `GetExpectedKernelType` and `trans` functions can help us to implement it. + +`GetExpectedKernelType` should get the context, and this operator can return the best `KernelType`. +`trans` would be like this: + +```c++ +void trans(inputs, ctx) override { + if (NoNeedTrans()) { + return; + } + // find reorder primitive by op_key from context + auto& dev_ctx = ctx.template device_context(); + auto& p = dev_ctx.findPrimitive(op_key + "_reorder_input"); + auto& i = dev_ctx.findMemory(op_key + "_src_input"); + + if (p == nullptr || i == nullptr || changeSized(i, input)) { + auto prim = createPrimitiveDesc(ctx); + auto src = createMemory(memoryDesc(input->dims(), actual_layout), input->data); + auto newbuffer = paddle::memory::Alloc(ctx.GetPlace(), input->size_in_bytes()); + auto dst = createMemory(p->expected_desc(), newbuffer->data); + auto reorder_primitive(new mkldnn::reorder(src, dst)); + + dev_ctx.addMemory(op_key+"_src_input", src); + dev_ctx.addMemory(op_key+"_input", dst); + dev_ctx.addPrimitive(op_key+"_reorder_input", reorder_primitive); + } + + p = dev_ctx.findPrimitive(op_key + "_reorder_input"); + PADDLE_ENFORCE(p, "Should have Reorder Primitive"); + dev_ctx.submit(p); + if (! this->isMKLDNNKernel()) { + // execute immediately only if this is not mkldnn kernel function. + // otherwise, it can be executed with the operator primitive in Compute + dev_ctx.stream(); + } + // after submit, the input tensor in ExecutionContext should be changed as the converted one + // there should be another mechanism to ensure this +} +``` + +### Unit Test +All the functions should be tested corresponding. +TBD