Commit b0472a33, author: Jacek Czaja, committer: Tao Luo

[MKL-DNN] Design document on acquire API (#1420)

Parent c8c30145
# Design Doc: MKL-DNN Acquire API
MKL-DNN kernels that use the MKL-DNN API tend to be quite complex because of:
* the number of MKL-DNN API calls needed, most of which are repeated across all MKL-DNN kernels
* the caching mechanism for MKL-DNN objects (conceptually the same across all Paddle MKL-DNN kernels)
* the still-evolving MKL-DNN API, which makes Paddle MKL-DNN kernels difficult to maintain

Hence the Acquire API was created as a wrapper around the MKL-DNN API that addresses the issues listed above.
### Common functionality
Each MKL-DNN kernel essentially creates MKL-DNN memory objects, then creates MKL-DNN computational primitives and, as a last step, triggers execution of the created primitives. Creating these objects requires at least a few calls to the MKL-DNN API for each object, and the code becomes much more complex once caching of the created objects is added. Moreover, this code is very similar across MKL-DNN kernels, so the Acquire API was designed to provide an easy way of creating and caching these MKL-DNN objects. With the common code implemented once inside the Acquire API and reused by operators, writing a given operator requires less effort. It also makes MKL-DNN kernels shorter to integrate and less prone to errors.
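
The create-or-reuse pattern described above can be sketched in a few lines. This is only an illustration under stated assumptions: `BlobCache` and its `Acquire` method are hypothetical names, not the actual Paddle or MKL-DNN classes, and plain `std::shared_ptr<void>` entries stand in for cached MKL-DNN objects.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-in for the storage that keeps created objects alive
// between kernel invocations.
class BlobCache {
 public:
  // Returns the object stored under `key`. On a cache miss the object is
  // built by calling `create` and stored for later calls.
  template <typename T, typename Factory>
  std::shared_ptr<T> Acquire(const std::string& key, Factory create) {
    auto it = blobs_.find(key);
    if (it != blobs_.end()) {
      return std::static_pointer_cast<T>(it->second);  // cache hit
    }
    std::shared_ptr<T> object = create();  // cache miss: build once
    assert(object != nullptr);
    blobs_[key] = object;
    return object;
  }

  std::size_t Size() const { return blobs_.size(); }

 private:
  std::map<std::string, std::shared_ptr<void>> blobs_;
};
```

Every kernel that needs a memory object or a primitive would go through one such call instead of repeating the lookup, creation and insertion steps by hand, which is the repetition the Acquire API is meant to remove.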
### Details of Acquire API
The basic element of the Acquire API is the so-called handler. The base handler class, MKLDNNHandlerT, implements the code common to all operators that use the Acquire API. In the picture below, the rightmost nodes (grouped under "Base MKLDNNHandler") represent the common functionality used by the Softmax and activation MKL-DNN kernels. Apart from the base handler, there are derived handlers that implement functionality specific to a given operator, e.g. constructing the caching key for that operator or adding a non-standard function for getting workspace memory objects (nodes grouped under "Derived Handlers"). The leftmost nodes are the entry functions (Compute) of the Softmax and activation MKL-DNN kernels.
![](images/acquire.svg)
Caching of MKL-DNN objects is already implemented in the base MKLDNNHandler, so most of the time you do not have to consider caching when implementing a derived handler.
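
The split between the base handler and the derived handlers can be sketched as follows. This is a simplified illustration, not the real class definitions: `HandlerBase`, `SoftmaxHandler`, `ActivationHandler`, the descriptor structs and the `Cache` alias are all hypothetical stand-ins. Only the method name AcquireForwardPrimitiveDescriptor is taken from the picture above.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Hypothetical cache type and stand-ins for MKL-DNN primitive descriptors.
using Cache = std::map<std::string, std::shared_ptr<void>>;
struct SoftmaxDesc { int axis; };
struct ActivationDesc { float alpha; };

// Base handler: holds the caching logic shared by all operators.
template <typename Desc>
class HandlerBase {
 public:
  HandlerBase(Cache* cache, std::string key)
      : cache_(cache), key_(std::move(key)) {
    assert(cache_ != nullptr);
  }
  std::shared_ptr<Desc> fwd_pd() const { return fwd_pd_; }

 protected:
  // Builds the forward descriptor on a cache miss, reuses it on a hit.
  template <typename... Args>
  void AcquireForwardPrimitiveDescriptor(Args&&... args) {
    const std::string key_pd = key_ + "@fwd_pd";
    auto it = cache_->find(key_pd);
    if (it == cache_->end()) {
      fwd_pd_ = std::make_shared<Desc>(Desc{std::forward<Args>(args)...});
      (*cache_)[key_pd] = fwd_pd_;
    } else {
      fwd_pd_ = std::static_pointer_cast<Desc>(it->second);
    }
  }

 private:
  Cache* cache_;
  std::string key_;
  std::shared_ptr<Desc> fwd_pd_;
};

// Derived handlers: only the key and the descriptor arguments are
// operator-specific; no caching code is repeated here.
class SoftmaxHandler : public HandlerBase<SoftmaxDesc> {
 public:
  SoftmaxHandler(Cache* cache, int axis)
      : HandlerBase<SoftmaxDesc>(cache, "softmax-" + std::to_string(axis)) {
    AcquireForwardPrimitiveDescriptor(axis);
  }
};

class ActivationHandler : public HandlerBase<ActivationDesc> {
 public:
  ActivationHandler(Cache* cache, float alpha)
      : HandlerBase<ActivationDesc>(cache,
                                    "activation-" + std::to_string(alpha)) {
    AcquireForwardPrimitiveDescriptor(alpha);
  }
};
```

In this sketch, constructing two `SoftmaxHandler` objects with the same axis yields the same descriptor object, while the derived classes themselves contain no cache lookups.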
### Using the Acquire API to implement MKL-DNN kernels
#### 1. Creating MKLDNNHandler
The first step is to create a derived handler for the target MKL-DNN kernel (operator). For the LRN op this is LRNMKLDNNHandler, which inherits from MKLDNNHandlerT.
The goal of a derived handler is to provide operator-specific functionality: creating the caching key and creating the forward and backward MKL-DNN primitive descriptors.
It is best to look at existing derived handlers and implement the new one by analogy.
Example of creating the LRN MKLDNNHandler:

    // n (the LRN window size) and x (the input tensor) are obtained earlier in the kernel
    const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
    const float beta = ctx.Attr<float>("beta");
    const float k = ctx.Attr<float>("k");
    bool is_test = ctx.Attr<bool>("is_test");
    auto dims = paddle::framework::vectorize<int>(x->dims());
    platform::LRNMKLDNNHandler<T> handler(dims, n, alpha, beta, k, x->format(),
                                          is_test, dev_ctx, ctx.GetPlace(),
                                          ctx.op().Output("Out"));
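
The caching key mentioned above has to distinguish every configuration that would produce a different primitive. A minimal sketch of such a key builder is shown below, assuming the key is a plain string; `CreateLRNKey` is a hypothetical helper written for this illustration and is not the function used by the real handler.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper: serializes every value that influences the created
// primitive into one string. A separator follows each element, so dims
// {1, 23} and {12, 3} produce different keys.
std::string CreateLRNKey(const std::vector<int>& dims, int n, float alpha,
                         float beta, float k, bool is_test,
                         const std::string& output_name) {
  assert(!dims.empty());
  std::ostringstream key;
  for (int d : dims) {
    key << d << '-';
  }
  key << n << '-' << alpha << '-' << beta << '-' << k << '-'
      << (is_test ? "test" : "train") << '-' << output_name;
  return key.str();
}
```

Including the output variable name, as the handler constructor above does with `ctx.op().Output("Out")`, keeps two LRN ops with identical shapes and attributes from sharing cached objects.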
#### 2. Creating MKL-DNN Memory objects
Once the derived handler exists, it is time to get the required MKL-DNN memory objects. A memory object can either wrap Tensor data or allocate data on its own.
The family of functions for getting memory objects includes:
* AcquireSrcMemory
* AcquireDstMemory
* AcquireDiffDstMemory
* etc...

Each of these functions expects a Tensor as a parameter, so that the MKL-DNN memory object wraps the Tensor (the recommended way). If this is not possible, as is the case for some workspace memory objects, calling the function without a Tensor creates an MKL-DNN memory object with its own allocation.
Example usage based on the LRN MKL-DNN kernel:

    auto src_memory = handler.AcquireSrcMemory(x);    // x is the input tensor of LRN
    auto dst_memory = handler.AcquireDstMemory(out);  // out is the output tensor of LRN
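
The difference between a memory object that wraps a Tensor and one that owns its allocation can be sketched with simple stand-in types. Everything below is hypothetical: `Tensor` and `Memory` are simplified substitutes, and the free functions only mimic the behaviour of the handler methods described above.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for a framework tensor that owns its buffer.
struct Tensor {
  std::vector<float> data;
};

// Hypothetical stand-in for an MKL-DNN memory object.
class Memory {
 public:
  // Wraps an existing buffer: no copy and no allocation.
  explicit Memory(float* external) : handle_(external) {
    assert(external != nullptr);
  }
  // Allocates its own buffer of `size` elements.
  explicit Memory(std::size_t size) : owned_(size), handle_(owned_.data()) {}

  Memory(const Memory&) = delete;
  Memory& operator=(const Memory&) = delete;

  float* handle() const { return handle_; }
  bool owns_buffer() const { return !owned_.empty(); }

 private:
  std::vector<float> owned_;  // declared first, so it is ready before handle_
  float* handle_;
};

// Recommended path: the memory object writes straight into the tensor.
std::shared_ptr<Memory> AcquireDstMemory(Tensor* out) {
  return std::make_shared<Memory>(out->data.data());
}

// Fallback path, e.g. for a workspace that has no tensor behind it.
std::shared_ptr<Memory> AcquireWorkspaceMemory(std::size_t size) {
  return std::make_shared<Memory>(size);
}
```

With the wrapping variant, whatever a primitive writes through the memory object is immediately visible in the tensor, which is why no copy back into the output is needed after execution.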
#### 3. Creating MKL-DNN computational primitives
Once the handler and the MKL-DNN memory objects are available, the next step is to get the computational MKL-DNN primitive. This is done with AcquireForwardPrimitive (for a forward pass op) or AcquireBackwardPrimitive (for a grad pass op).

Example usage based on the LRN MKL-DNN kernel:

    auto lrn_p = handler.AcquireForwardPrimitive(*src_memory, *dst_memory);
#### 4. Execution of MKL-DNN computational primitives
With the memory objects and the computational primitive in place, the primitive can be executed. Example for the LRN op:

    std::vector<mkldnn::primitive> pipeline = {*lrn_p};
    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
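
The submit-and-wait step can be pictured with a small model in which a pipeline is an ordered list of work items. The sketch below is not MKL-DNN code: `Primitive` and `SubmitAndWait` are hypothetical stand-ins, and it assumes that primitives run in the order they appear in the pipeline.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical stand-in for mkldnn::primitive: any callable unit of work.
using Primitive = std::function<void()>;

// Hypothetical stand-in for an eager stream: runs each primitive in
// submission order and returns only after all of them have finished.
void SubmitAndWait(const std::vector<Primitive>& pipeline) {
  for (const Primitive& p : pipeline) {
    assert(static_cast<bool>(p));  // an empty primitive cannot be executed
    p();
  }
}
```

Under this model a pipeline may hold more than one primitive, for example a data reorder followed by the computation itself, and the second primitive sees the results of the first.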
#### 5. Registering MKL-DNN memory format in corresponding Tensor
The last step is to register the format of the MKL-DNN output memory object in the output Tensor, i.e. to set Tensor::format_ to the MKL-DNN enum value that corresponds to the way the Tensor data is arranged (NCHW, NCHW16C etc.). This enum value can be taken from the dst memory object (the wrapper of the Output tensor) in the forward pass, or from the diff_src memory object (the wrapper of the X_grad Tensor) in the backward pass.

Example of registering the MKL-DNN format in the output tensor:

    out->set_layout(framework::DataLayout::kMKLDNN);
    out->set_format(platform::GetMKLDNNFormat(*dst_memory));
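
Why the format has to be recorded can be illustrated with a small model. All types and functions below are hypothetical simplifications written for this sketch (`Tensor`, `Memory`, `RegisterOutputFormat`, `NeedsReorder`); the assumption is that a later kernel inspects the recorded format to decide whether the data must be rearranged before use.

```cpp
#include <cassert>

// Hypothetical, reduced versions of the layout and format enums.
enum class DataLayout { kNCHW, kMKLDNN };
enum class MemoryFormat { kUndef, kNCHW, kNCHW16C };

// Hypothetical stand-in for a tensor that remembers how its data is arranged.
class Tensor {
 public:
  void set_layout(DataLayout layout) { layout_ = layout; }
  void set_format(MemoryFormat format) { format_ = format; }
  DataLayout layout() const { return layout_; }
  MemoryFormat format() const { return format_; }

 private:
  DataLayout layout_ = DataLayout::kNCHW;
  MemoryFormat format_ = MemoryFormat::kUndef;
};

// Hypothetical stand-in for the dst memory object produced by a kernel.
struct Memory {
  MemoryFormat format;
};

// Copies the arrangement chosen by the primitive into the output tensor.
void RegisterOutputFormat(Tensor* out, const Memory& dst_memory) {
  assert(out != nullptr);
  out->set_layout(DataLayout::kMKLDNN);
  out->set_format(dst_memory.format);
}

// A later kernel can compare formats to decide whether a reorder is needed.
bool NeedsReorder(const Tensor& in, MemoryFormat wanted) {
  return in.format() != wanted;
}
```

If the registration step were skipped in this model, the tensor would still report an undefined format and a consumer could not tell how the buffer is laid out.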
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<!-- Generated by graphviz version 2.38.0 (20140413.2041)
-->
<!-- Title: %3 Pages: 1 -->
<svg width="1392pt" height="428pt"
viewBox="0.00 0.00 1392.00 428.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 424)">
<title>%3</title>
<polygon fill="white" stroke="none" points="-4,4 -4,-424 1388,-424 1388,4 -4,4"/>
<g id="clust1" class="cluster"><title>cluster_A</title>
<polygon fill="none" stroke="black" stroke-dasharray="1,5" points="345,-181 345,-367 894,-367 894,-181 345,-181"/>
<text text-anchor="middle" x="619.5" y="-351.8" font-family="Times,serif" font-size="14.00">Derived Handlers</text>
</g>
<g id="clust2" class="cluster"><title>cluster_B</title>
<polygon fill="none" stroke="black" stroke-dasharray="1,5" points="914,-8 914,-412 1376,-412 1376,-8 914,-8"/>
<text text-anchor="middle" x="1145" y="-396.8" font-family="Times,serif" font-size="14.00">Base MKLDNNHandler</text>
</g>
<!-- Node0x490c380 -->
<g id="node1" class="node"><title>Node0x490c380</title>
<polygon fill="none" stroke="black" points="0,-331 0,-367 317,-367 317,-331 0,-331"/>
<text text-anchor="start" x="8" y="-345.3" font-family="Times,serif" font-size="14.00">SoftmaxMKLDNNKernel::Compute()</text>
</g>
<!-- Node0x4915e90 -->
<g id="node3" class="node"><title>Node0x4915e90</title>
<polygon fill="none" stroke="black" points="353,-299 353,-335 886,-335 886,-299 353,-299"/>
<text text-anchor="start" x="361" y="-313.3" font-family="Times,serif" font-size="14.00">SoftmaxMKLDNNHandler::SoftmaxMKLDNNHandler&lt;forward&gt;()</text>
</g>
<!-- Node0x490c380&#45;&gt;Node0x4915e90 -->
<g id="edge1" class="edge"><title>Node0x490c380&#45;&gt;Node0x4915e90</title>
<path fill="none" stroke="black" stroke-width="2" d="M254,-330.942C254,-322.594 254,-314.5 254,-314.5 254,-314.5 342.91,-314.5 342.91,-314.5"/>
<polygon fill="black" stroke="black" stroke-width="2" points="342.91,-318 352.91,-314.5 342.91,-311 342.91,-318"/>
</g>
<!-- Node0x49164c0 -->
<g id="node5" class="node"><title>Node0x49164c0</title>
<polygon fill="none" stroke="black" points="922,-126 922,-162 1368,-162 1368,-126 922,-126"/>
<text text-anchor="start" x="930" y="-140.3" font-family="Times,serif" font-size="14.00">MKLDNNHandlerT::AcquireSrcMemory()</text>
</g>
<!-- Node0x490c380&#45;&gt;Node0x49164c0 -->
<g id="edge2" class="edge"><title>Node0x490c380&#45;&gt;Node0x49164c0</title>
<path fill="none" stroke="black" d="M317.141,-340.833C331.632,-340.833 341,-340.833 341,-340.833 341,-340.833 341,-146.833 341,-146.833 341,-146.833 911.842,-146.833 911.842,-146.833"/>
<polygon fill="black" stroke="black" points="911.842,-150.333 921.842,-146.833 911.842,-143.333 911.842,-150.333"/>
</g>
<!-- Dst -->
<g id="node6" class="node"><title>Dst</title>
<polygon fill="none" stroke="black" points="922,-17 922,-53 1368,-53 1368,-17 922,-17"/>
<text text-anchor="start" x="930" y="-31.3" font-family="Times,serif" font-size="14.00">MKLDNNHandlerT::AcquireDstMemory()</text>
</g>
<!-- Node0x490c380&#45;&gt;Dst -->
<g id="edge4" class="edge"><title>Node0x490c380&#45;&gt;Dst</title>
<path fill="none" stroke="black" d="M317.011,-338.167C324.519,-338.167 329,-338.167 329,-338.167 329,-338.167 329,-41.1667 329,-41.1667 329,-41.1667 911.755,-41.1667 911.755,-41.1667"/>
<polygon fill="black" stroke="black" points="911.755,-44.6668 921.755,-41.1667 911.755,-37.6668 911.755,-44.6668"/>
</g>
<!-- Node0x491bca0 -->
<g id="node7" class="node"><title>Node0x491bca0</title>
<polygon fill="none" stroke="black" points="922,-235 922,-271 1368,-271 1368,-235 922,-235"/>
<text text-anchor="start" x="930" y="-249.3" font-family="Times,serif" font-size="14.00">MKLDNNHandlerT::AcquireForwardPrimitive()</text>
</g>
<!-- Node0x490c380&#45;&gt;Node0x491bca0 -->
<g id="edge3" class="edge"><title>Node0x490c380&#45;&gt;Node0x491bca0</title>
<path fill="none" stroke="black" d="M64,-330.821C64,-304.571 64,-259.167 64,-259.167 64,-259.167 911.999,-259.167 911.999,-259.167"/>
<polygon fill="black" stroke="black" points="911.999,-262.667 921.999,-259.167 911.999,-255.667 911.999,-262.667"/>
</g>
<!-- Node0x4ab38f0 -->
<g id="node2" class="node"><title>Node0x4ab38f0</title>
<polygon fill="none" stroke="black" points="0,-158 0,-194 317,-194 317,-158 0,-158"/>
<text text-anchor="start" x="8" y="-172.3" font-family="Times,serif" font-size="14.00">MKLDNNActivationKernel::Compute()</text>
</g>
<!-- Node0x4b2e4f0 -->
<g id="node4" class="node"><title>Node0x4b2e4f0</title>
<polygon fill="none" stroke="black" points="353,-190 353,-226 886,-226 886,-190 353,-190"/>
<text text-anchor="start" x="361" y="-204.3" font-family="Times,serif" font-size="14.00">ActivationMKLDNNHandler::ActivationMKLDNNHandler&lt;forward&gt;()</text>
</g>
<!-- Node0x4ab38f0&#45;&gt;Node0x4b2e4f0 -->
<g id="edge10" class="edge"><title>Node0x4ab38f0&#45;&gt;Node0x4b2e4f0</title>
<path fill="none" stroke="black" stroke-width="2" d="M191,-194.058C191,-202.406 191,-210.5 191,-210.5 191,-210.5 342.997,-210.5 342.997,-210.5"/>
<polygon fill="black" stroke="black" stroke-width="2" points="342.998,-214 352.997,-210.5 342.997,-207 342.998,-214"/>
</g>
<!-- Node0x4ab38f0&#45;&gt;Node0x49164c0 -->
<g id="edge7" class="edge"><title>Node0x4ab38f0&#45;&gt;Node0x49164c0</title>
<path fill="none" stroke="black" d="M212,-157.948C212,-147.409 212,-136.167 212,-136.167 212,-136.167 911.82,-136.167 911.82,-136.167"/>
<polygon fill="black" stroke="black" points="911.82,-139.667 921.82,-136.167 911.82,-132.667 911.82,-139.667"/>
</g>
<!-- Node0x4ab38f0&#45;&gt;Dst -->
<g id="edge8" class="edge"><title>Node0x4ab38f0&#45;&gt;Dst</title>
<path fill="none" stroke="black" d="M106,-157.954C106,-118.657 106,-28.8333 106,-28.8333 106,-28.8333 911.789,-28.8333 911.789,-28.8333"/>
<polygon fill="black" stroke="black" points="911.789,-32.3334 921.789,-28.8333 911.789,-25.3334 911.789,-32.3334"/>
</g>
<!-- Node0x4ab38f0&#45;&gt;Node0x491bca0 -->
<g id="edge9" class="edge"><title>Node0x4ab38f0&#45;&gt;Node0x491bca0</title>
<path fill="none" stroke="black" d="M127,-194.241C127,-215.229 127,-246.833 127,-246.833 127,-246.833 911.91,-246.833 911.91,-246.833"/>
<polygon fill="black" stroke="black" points="911.91,-250.333 921.91,-246.833 911.91,-243.333 911.91,-250.333"/>
</g>
<!-- Node0x496cfc0 -->
<g id="node8" class="node"><title>Node0x496cfc0</title>
<polygon fill="none" stroke="black" points="922,-344 922,-380 1368,-380 1368,-344 922,-344"/>
<text text-anchor="start" x="930" y="-358.3" font-family="Times,serif" font-size="14.00">MKLDNNHandlerT::AcquireForwardPrimitiveDescriptor()</text>
</g>
<!-- Node0x4915e90&#45;&gt;Node0x496cfc0 -->
<g id="edge5" class="edge"><title>Node0x4915e90&#45;&gt;Node0x496cfc0</title>
<path fill="none" stroke="black" d="M886.099,-317C1016.18,-317 1145,-317 1145,-317 1145,-317 1145,-333.956 1145,-333.956"/>
<polygon fill="black" stroke="black" points="1141.5,-333.956 1145,-343.956 1148.5,-333.956 1141.5,-333.956"/>
</g>
<!-- Node0x4b2e4f0&#45;&gt;Node0x496cfc0 -->
<g id="edge6" class="edge"><title>Node0x4b2e4f0&#45;&gt;Node0x496cfc0</title>
<path fill="none" stroke="black" d="M886.203,-208C897.397,-208 904,-208 904,-208 904,-208 904,-355.5 904,-355.5 904,-355.5 911.721,-355.5 911.721,-355.5"/>
<polygon fill="black" stroke="black" points="911.721,-359 921.721,-355.5 911.721,-352 911.721,-359"/>
</g>
</g>
</svg>
MKL-DNN Acquire API
--------------------------------------
.. toctree::
   :maxdepth: 1

   acquire_api.md
digraph {
rankdir=LR
weight=0.5
concentrate=true
splines=ortho
newrank=true
nodesep=1
node[width=4.4,shape=box]
Node0x490c380 [shape=record,label="SoftmaxMKLDNNKernel::Compute()\l"];
Node0x4ab38f0 [shape=record,label="MKLDNNActivationKernel::Compute()\l"];
subgraph cluster_A {
label="Derived Handlers"
node[width=7.4,shape=box]
style=dotted
// Dummy[shape=record,label="", color=invis];
Node0x4915e90 [shape=record,label="SoftmaxMKLDNNHandler::SoftmaxMKLDNNHandler\<forward\>()\l"];
Node0x4b2e4f0 [shape=record,label="ActivationMKLDNNHandler::ActivationMKLDNNHandler\<forward\>()\l"];
}
subgraph cluster_B {
label="Base MKLDNNHandler"
style=dotted
node[width=6.2,shape=box]
Node0x49164c0 [shape=record,label="MKLDNNHandlerT::AcquireSrcMemory()\l"];
Dst[shape=record,label="MKLDNNHandlerT::AcquireDstMemory()\l"];
Node0x491bca0 [shape=record,label="MKLDNNHandlerT::AcquireForwardPrimitive()\l"];
Node0x496cfc0 [shape=record,label="MKLDNNHandlerT::AcquireForwardPrimitiveDescriptor()\l"];
}
Node0x490c380 -> Node0x4915e90[style="bold"];
Node0x490c380 -> Node0x49164c0;
Node0x490c380 -> Node0x491bca0;
Node0x490c380 -> Dst;
Node0x4915e90 -> Node0x496cfc0;
{rank=same Node0x4ab38f0 Node0x490c380 } // Compute level
{rank=same Node0x4915e90 Node0x4b2e4f0 } // Derived Handler level
{rank=same Node0x49164c0 Dst Node0x491bca0 Node0x496cfc0 } // Base Handler level
Node0x4b2e4f0 -> Node0x496cfc0
Node0x4ab38f0 -> Node0x49164c0
Node0x4ab38f0 -> Dst
Node0x4ab38f0 -> Node0x491bca0
Node0x4ab38f0 -> Node0x4b2e4f0[style="bold"]
}