Commit da542266 authored by: T tensor-tang

follow comments and refine

Parent 661f03b4
# Design Doc: Add MKLDNN Kernel in Fluid Operator

## Principles

First of all, we should follow some basic principles:

1. [How to write a new operator](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_en.md). We are trying to add a new kind of kernel into operators, so basically we should follow this doc.
2. [Supporting new Device/Library](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/support_new_device.md). Since MKLDNN is a new library to fluid, we should add `MKLDNNDeviceContext` and maybe `mkldnn_helper.h`, just like [cudnn_helper.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/cudnn_helper.h).
3. [Switch Kernel](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md). Another important point is that we should ensure the data synchronization between different kernel types, which is this [topic](https://github.com/PaddlePaddle/Paddle/issues/6549). So basically we should override `GetExpectedKernelType` and `trans` functions to support switching kernels; a minimal sketch of such an override is given right after this list.
4. [The Keys of Operator Kernel Type](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md). Kernel Type is a pivotal concept which records the `Place`, `Library`, `DataType` and `Layout`.
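As mentioned in the third principle, a minimal sketch of such an override could look like the snippet below. The class name `ConvMKLDNNOp` is made up for illustration, and the exact enum values (`DataLayout::kMKLDNN`, `LibraryType::kMKLDNN`) are assumptions based on the kernel-type design doc rather than settled API.

```c++
// Sketch only: ask the framework for an MKLDNN kernel on CPUPlace.
framework::OpKernelType ConvMKLDNNOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
      ctx.GetPlace(),                     // expected to be a CPUPlace
      framework::DataLayout::kMKLDNN,     // assumed MKLDNN layout key
      framework::LibraryType::kMKLDNN);   // assumed MKLDNN library key
}
```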

## Solution
In general, there are four steps we should follow to run a MKL-DNN primitive.

- Create a primitive descriptor that describes this operator
- Create the primitive itself from the primitive descriptor and the engine
- Create all the memory buffers that the primitive needs
- Launch a stream to execute the primitive created

More details can be found [here](http://01org.github.io/mkl-dnn); a rough sketch of these four steps is given below.
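As a rough, hedged illustration only (using the MKL-DNN 0.x C++ API; the shapes, buffers and the choice of a bias-less convolution are made up for demonstration and are not part of this design), the four steps map to raw MKL-DNN calls roughly as follows:

```c++
#include <vector>
#include "mkldnn.hpp"

void run_conv_once() {
  using namespace mkldnn;
  engine eng(engine::cpu, 0);

  // user buffers, just for demonstration
  std::vector<float> src_data(1 * 3 * 224 * 224), wgt_data(8 * 3 * 3 * 3),
      dst_data(1 * 8 * 222 * 222);

  // 1. a primitive descriptor that describes this operator (a 3x3 conv)
  memory::desc src_md({1, 3, 224, 224}, memory::data_type::f32, memory::format::nchw);
  memory::desc wgt_md({8, 3, 3, 3}, memory::data_type::f32, memory::format::oihw);
  memory::desc dst_md({1, 8, 222, 222}, memory::data_type::f32, memory::format::nchw);
  convolution_forward::desc conv_d(prop_kind::forward, convolution_direct, src_md,
                                   wgt_md, dst_md, {1, 1}, {0, 0}, {0, 0},
                                   padding_kind::zero);
  convolution_forward::primitive_desc conv_pd(conv_d, eng);

  // 3. memory handles that wrap the buffers the primitive needs
  memory src(memory::primitive_desc(src_md, eng), src_data.data());
  memory wgt(memory::primitive_desc(wgt_md, eng), wgt_data.data());
  memory dst(memory::primitive_desc(dst_md, eng), dst_data.data());

  // 2. the primitive itself, created from the primitive descriptor and memories
  convolution_forward conv(conv_pd, src, wgt, dst);

  // 4. launch a stream to execute the primitive
  stream(stream::kind::eager).submit({conv}).wait();
}
```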
It's better to avoid reinitializing the primitives and memory handles of the first three steps in every iteration. \
So we plan to create a map to record all the `primitive` and `memory`, which should not take too much memory as discussed [here](https://github.com/PaddlePaddle/Paddle/issues/6822).
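A minimal sketch of what such a cache inside `MKLDNNDeviceContext` could look like is shown below. The member and method names (`addPrimitive`, `findPrimitive`, etc.) simply mirror the pseudocode later in this doc and are not an existing API; deriving from `CPUDeviceContext` is likewise an assumption.

```c++
// Sketch only: cache primitives/memories per op instance, keyed by op_key strings.
class MKLDNNDeviceContext : public CPUDeviceContext {
 public:
  void addPrimitive(const std::string& key, std::shared_ptr<mkldnn::primitive> p) {
    primitives_[key] = p;
  }
  std::shared_ptr<mkldnn::primitive> findPrimitive(const std::string& key) const {
    auto it = primitives_.find(key);
    return it == primitives_.end() ? nullptr : it->second;
  }
  // addMemory/findMemory and addPrimitiveDesc/findPrimitiveDesc would
  // follow exactly the same pattern with their own maps.

 private:
  std::unordered_map<std::string, std::shared_ptr<mkldnn::primitive>> primitives_;
  std::unordered_map<std::string, std::shared_ptr<mkldnn::memory>> memories_;
};
```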
It is assumed that the following three conditions are satisfied:

1. There is a unique key for each operator instance, which may be the actual name of the `Output Tensor` (see the sketch after this list).
2. The `Input Tensor` inside the `Compute` function is the one that has already been converted.
3. We can get the phase (e.g. `is_test`) inside the `Compute` function; otherwise we need to expose this attribute to the user.
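For the first assumption, one possible (hypothetical) way to form the `op_key` inside the kernel is to reuse the name of the output variable of this op instance:

```c++
// Sketch: the output variable name is unique per op instance in a program,
// so it could serve as the op_key used throughout this doc.
const std::string op_key = ctx.op().Output("Output");  // e.g. "conv2d_0.tmp_0" (illustrative)
```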
The algorithm of `Compute` would be described as follows; let's take conv as an example:

```c++
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), "It must use CPUPlace.");
PADDLE_ENFORCE(platform::is_mkldnn_library(ctx.GetLibrary()), "It must use MKLDNN Library.");

auto& dev_ctx = ctx.template device_context<platform::MKLDNNDeviceContext>();

// find the primitive by a unique key from the mkldnn context;
// the op_key should be a unique name of this op instance
auto p = dev_ctx.findPrimitive(op_key + "_fwd");

// assuming the input tensor inside this compute function is the one after conversion;
// this point should be guaranteed by another mechanism
auto i = dev_ctx.findMemory(op_key + "_input");

if (p == nullptr || i == nullptr || inputSizeChanged(p, i)) {
  auto fwd_primitive_desc = createPrimitiveDesc(ctx);
  auto* input = ctx.Input<Tensor>("Input");
  auto* filter = ctx.Input<Tensor>("Filter");
  auto* output = ctx.Output<Tensor>("Output");
  shared_ptr<mkldnn::memory> in(new mkldnn::memory(fwd_primitive_desc->src_primitive_desc(), input->data<T>()));
  shared_ptr<mkldnn::memory> wgt(new mkldnn::memory(fwd_primitive_desc->weights_primitive_desc(), filter->data<T>()));
  shared_ptr<mkldnn::memory> out(new mkldnn::memory(fwd_primitive_desc->dst_primitive_desc(), output->mutable_data<T>(ctx.GetPlace())));
  shared_ptr<mkldnn::conv_fwd> fwd_primitive(new mkldnn::conv_fwd(*fwd_primitive_desc, *in, *wgt, *out));

  dev_ctx.addMemory(op_key+"_input", in);
  dev_ctx.addMemory(op_key+"_output", out);
  dev_ctx.addMemory(op_key+"_filter", wgt);
  dev_ctx.addPrimitive(op_key+"_fwd", fwd_primitive);
  dev_ctx.addPrimitiveDesc(op_key+"_fwd_PD", fwd_primitive_desc);
}

p = dev_ctx.findPrimitive(op_key + "_fwd");

PADDLE_ENFORCE(p, "Should have forward Primitive");
PADDLE_ENFORCE(dev_ctx.findMemory(op_key+"_input"), "Should have input memory");
PADDLE_ENFORCE(dev_ctx.findMemory(op_key+"_output"), "Should have output memory");
PADDLE_ENFORCE(dev_ctx.findMemory(op_key+"_filter"), "Should have filter memory");
PADDLE_ENFORCE(dev_ctx.findPrimitiveDesc(op_key+"_fwd_PD"), "Should have forward PrimitiveDesc");

dev_ctx.submit(p);
dev_ctx.execute();  // the convert (reorder) primitive, if any, should already be contained
```
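The `submit`/`execute` pair above is not an existing MKL-DNN or Paddle call; one way it could be backed by an MKL-DNN stream is sketched here (the method and member names are assumptions, following the cache sketch earlier):

```c++
// Sketch only: queue primitives on submit(), run them in one stream on execute().
void MKLDNNDeviceContext::submit(std::shared_ptr<mkldnn::primitive> p) {
  pending_.push_back(*p);  // pending_ is a std::vector<mkldnn::primitive> member
}

void MKLDNNDeviceContext::execute() {
  mkldnn::stream(mkldnn::stream::kind::eager).submit(pending_).wait();
  pending_.clear();  // the queue only lives for one launch
}
```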
We should `reorder` the different Layout of data coming from or going to another device.
```c++
void trans(inputs, ctx) override {
  if (NoNeedTrans()) {
    return;
  }

  // find the reorder primitive by op_key from the context
  auto& dev_ctx = ctx.template device_context<platform::MKLDNNDeviceContext>();
  auto p = dev_ctx.findPrimitive(op_key + "_reorder_input");
  auto i = dev_ctx.findMemory(op_key + "_src_input");

  if (p == nullptr || i == nullptr || changeSized(i, input)) {
    auto prim = createPrimitiveDesc(ctx);
    // ...
    auto dst = createMemory(p->expected_desc(), newbuffer->data);
    auto reorder_primitive(new mkldnn::reorder(src, dst));

    dev_ctx.addMemory(op_key+"_src_input", src);
    dev_ctx.addMemory(op_key+"_input", dst);
    dev_ctx.addPrimitive(op_key+"_reorder_input", reorder_primitive);
  }

  p = dev_ctx.findPrimitive(op_key + "_reorder_input");
  PADDLE_ENFORCE(p, "Should have Reorder Primitive");

  dev_ctx.submit(p);

  if (!this->isMKLDNNKernel()) {
    // execute immediately only if this is not an mkldnn kernel function;
    // otherwise, it can be executed together with the operator primitive in Compute
    dev_ctx.stream();
  }

  // after submit, the input tensor in the ExecutionContext should be changed to the converted one;
  // there should be another mechanism to ensure this
}
```
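For reference, the cached reorder in `trans` corresponds roughly to the following raw MKL-DNN 0.x calls. This is only a sketch, not the `createMemory`/`createPrimitiveDesc` helpers above; `eng`, `conv_pd`, `dims` and `user_ptr` are assumed to exist in the surrounding code.

```c++
// Sketch: reorder a user nchw buffer into the layout the conv primitive expects.
mkldnn::memory::desc user_md(dims, mkldnn::memory::data_type::f32,
                             mkldnn::memory::format::nchw);
mkldnn::memory src(mkldnn::memory::primitive_desc(user_md, eng), user_ptr);
mkldnn::memory dst(conv_pd.src_primitive_desc());  // expected layout, owns its buffer
mkldnn::reorder reorder_prim(src, dst);
mkldnn::stream(mkldnn::stream::kind::eager).submit({reorder_prim}).wait();
```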
### Unit Test

All the functions should be tested correspondingly.

TBD
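As one possible starting point (a sketch under the assumption that `MKLDNNDeviceContext` exposes the `addPrimitive`/`findPrimitive` cache used above; the header path and constructor are also assumptions), a test of the cache could look like:

```c++
#include <gtest/gtest.h>
#include <memory>
#include <vector>
#include "mkldnn.hpp"
#include "paddle/platform/device_context.h"  // assumed header of MKLDNNDeviceContext

TEST(MKLDNNDeviceContext, PrimitiveCacheRoundTrip) {
  paddle::platform::MKLDNNDeviceContext dev_ctx(paddle::platform::CPUPlace());
  EXPECT_EQ(dev_ctx.findPrimitive("never_added"), nullptr);

  // build a trivial reorder between two identical 1x1x2x2 buffers
  mkldnn::engine eng(mkldnn::engine::cpu, 0);
  std::vector<float> a(4, 1.0f), b(4, 0.0f);
  mkldnn::memory::desc md({1, 1, 2, 2}, mkldnn::memory::data_type::f32,
                          mkldnn::memory::format::nchw);
  auto src = std::make_shared<mkldnn::memory>(mkldnn::memory::primitive_desc(md, eng), a.data());
  auto dst = std::make_shared<mkldnn::memory>(mkldnn::memory::primitive_desc(md, eng), b.data());
  auto p = std::make_shared<mkldnn::reorder>(*src, *dst);

  dev_ctx.addPrimitive("some_op_key_reorder_input", p);
  EXPECT_EQ(dev_ctx.findPrimitive("some_op_key_reorder_input"), p);
}
```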