Commit c8e3e20e authored by Travis CI

Deploy to GitHub Pages: fac25fb5

Parent: c3c30bb0
# Design Doc: Add MKLDNN Kernel in Fluid Operator
## Principles
First of all, we should follow some basic principles:
1. [How to write a new operator](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_en.md). We are trying to add a new kind of kernel into operators, so basically we should follow this doc.
2. [Supporting new Device/Library](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/support_new_device.md). Since MKLDNN is a new library to Fluid, we should add `MKLDNNDeviceContext` and maybe `mkldnn_helper.h`, just like [cudnn_helper.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/cudnn_helper.h).
3. [Switch Kernel](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md). Another important point is that we should ensure data synchronization between different kernel types, which is discussed in this [topic](https://github.com/PaddlePaddle/Paddle/issues/6549). So basically we should override the `GetExpectedKernelType` and `trans` functions to support switching kernels.
4. [The Keys of Operator Kernel Type](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md). Kernel Type is a pivotal concept that records the `Place`, `Library`, `DataType` and `Layout`.
## Solution
In general, there are four steps we should follow to run an MKL-DNN primitive.
- Create a primitive descriptor that describes this operator
- Create the primitive itself from the primitive descriptor and the engine
- Create all the memory buffers that the primitive needs
- Launch a stream to execute the created primitive

More details can be found [here](http://01org.github.io/mkl-dnn); a minimal end-to-end sketch of these four steps with the raw MKL-DNN C++ API is shown below.
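To make the four steps concrete, the block below is a minimal, self-contained sketch of one forward convolution written directly against the MKL-DNN C++ API (v0.x era). The shapes, buffer names, and the choice of `forward_inference` are illustrative assumptions, not part of this design.
```c++
#include <mkldnn.hpp>  // Intel MKL-DNN C++ API (v0.x assumed)

// A rough sketch of the four steps for a single convolution forward pass.
void run_conv_once(float* src_data, float* wgt_data, float* dst_data) {
  using namespace mkldnn;
  engine eng(engine::cpu, 0);

  // Illustrative shapes: NCHW input, OIHW weights, 3x3 kernel, stride 1, no padding.
  memory::dims src_dims = {1, 3, 224, 224};
  memory::dims wgt_dims = {8, 3, 3, 3};
  memory::dims dst_dims = {1, 8, 222, 222};
  memory::desc src_md(src_dims, memory::data_type::f32, memory::format::nchw);
  memory::desc wgt_md(wgt_dims, memory::data_type::f32, memory::format::oihw);
  memory::desc dst_md(dst_dims, memory::data_type::f32, memory::format::nchw);

  // 1. Create a primitive descriptor that describes this operator.
  convolution_forward::desc fwd_desc(
      prop_kind::forward_inference, convolution_direct, src_md, wgt_md, dst_md,
      /*strides=*/{1, 1}, /*padding_l=*/{0, 0}, /*padding_r=*/{0, 0},
      padding_kind::zero);
  auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, eng);

  // 2. Create all the memory buffers that the primitive needs.
  memory src_mem(fwd_pd.src_primitive_desc(), src_data);
  memory wgt_mem(fwd_pd.weights_primitive_desc(), wgt_data);
  memory dst_mem(fwd_pd.dst_primitive_desc(), dst_data);

  // 3. Create the primitive itself from the primitive descriptor and the memories.
  convolution_forward fwd(fwd_pd, src_mem, wgt_mem, dst_mem);

  // 4. Launch a stream to execute the created primitive.
  stream(stream::kind::eager).submit({fwd}).wait();
}
```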
It's better to avoid re-initializing the primitives and memory handles of the first three stages in every iteration. So we plan to create a map to record all the `primitive` and `memory` objects, which should not take too much memory as discussed [here](https://github.com/PaddlePaddle/Paddle/issues/6822).
It's assumed that the following three conditions are satisfied.
1. There is a unique key for each operator instance. It may be the actual name of the `Output Tensor`; a hypothetical sketch of building such a key follows this list.
2. The `Input Tensor` inside the `Compute` function is the converted one.
3. We can get the phase (e.g. `is_test`) inside the `Compute` function; otherwise we need to expose this attribute to the user.
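For condition 1, one hypothetical way to form such a key is to combine the operator type with the unique name of its output variable. This is only a sketch of the idea, not an agreed interface:
```c++
// Hypothetical sketch: combine the operator type with the unique name of its
// output variable to get a per-instance key for the primitive/memory maps.
const std::string op_key = ctx.op().Type() + "-" + ctx.op().Output("Output");
```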
### Compute
The algorithm of `Compute` is described as follows, taking convolution as an example.
```c++
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), "It must use CPUPlace.");
PADDLE_ENFORCE(platform::is_mkldnn_library(ctx.GetLibrary()), "It must use MKLDNN Library.");
auto& dev_ctx = ctx.template device_context<platform::MKLDNNDeviceContext>();
// find the primitive by its unique key from the mkldnn device context;
// op_key should be a unique name of this op instance
auto p = dev_ctx.findPrimitive(op_key + "_fwd");
// assuming the input tensor inside this compute function is already the converted one;
// this point should be guaranteed by another mechanism
auto i = dev_ctx.findMemory(op_key + "_input");
if (p == nullptr || i == nullptr || inputSizeChanged(p, i)) {
  auto fwd_primitive_desc = createPrimitiveDesc(ctx);
  auto* input = ctx.Input<Tensor>("Input");
  auto* filter = ctx.Input<Tensor>("Filter");
  auto* output = ctx.Output<Tensor>("Output");
  shared_ptr<mkldnn::memory> in(new mkldnn::memory(fwd_primitive_desc->src_primitive_desc(), input->data<T>()));
  shared_ptr<mkldnn::memory> wgt(new mkldnn::memory(fwd_primitive_desc->weights_primitive_desc(), filter->data<T>()));
  shared_ptr<mkldnn::memory> out(new mkldnn::memory(fwd_primitive_desc->dst_primitive_desc(), output->mutable_data<T>(ctx.GetPlace())));
  shared_ptr<mkldnn::conv_fwd> fwd_primitive(new mkldnn::conv_fwd(*fwd_primitive_desc, *in, *wgt, *out));
  dev_ctx.addMemory(op_key + "_input", in);
  dev_ctx.addMemory(op_key + "_output", out);
  dev_ctx.addMemory(op_key + "_filter", wgt);
  dev_ctx.addPrimitive(op_key + "_fwd", fwd_primitive);
  dev_ctx.addPrimitiveDesc(op_key + "_fwd_PD", fwd_primitive_desc);
}
p = dev_ctx.findPrimitive(op_key + "_fwd");
PADDLE_ENFORCE(p, "Should have forward Primitive");
PADDLE_ENFORCE(dev_ctx.findMemory(op_key + "_input"), "Should have input memory");
PADDLE_ENFORCE(dev_ctx.findMemory(op_key + "_output"), "Should have output memory");
PADDLE_ENFORCE(dev_ctx.findMemory(op_key + "_filter"), "Should have filter memory");
PADDLE_ENFORCE(dev_ctx.findPrimitiveDesc(op_key + "_fwd_PD"), "Should have forward PrimitiveDesc");
dev_ctx.submit(p);
dev_ctx.execute();  // the reorder (convert) primitive, if any, is already contained in the submitted pipeline
```
The `createPrimitiveDesc` function returns the primitive descriptor of this operator and would look like this:
```c++
// the exact signature is an assumption; it returns the primitive descriptor shared_ptr used above
shared_ptr<mkldnn::conv_fwd::primitive_desc> createPrimitiveDesc(const framework::ExecutionContext& ctx) {
  auto* input = ctx.Input<Tensor>("Input");
  auto* filter = ctx.Input<Tensor>("Filter");
  auto* output = ctx.Output<Tensor>("Output");
  std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
  std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
  std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
  int groups = ctx.Attr<int>("groups");
  algorithm algo = static_cast<algorithm>(ctx.Attr<int>("convolution_algorithm_option"));
  prop_kind pk = ctx.Attr<bool>("is_test") ? prop_kind::forward_inference : prop_kind::forward_training;
  auto fwd_desc = mkldnn::conv_fwd::desc(/* all the settings above */);
  shared_ptr<mkldnn::conv_fwd::primitive_desc> fwd_primitive_desc(new mkldnn::conv_fwd::primitive_desc(fwd_desc, ctx.getEngine()));
  return fwd_primitive_desc;
}
```
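The `/* all the settings above */` placeholder would roughly map onto the raw MKL-DNN convolution descriptor as sketched below, where `conv_fwd` stands for `mkldnn::convolution_forward`. The memory descriptors are assumed to be built from the input/filter/output dims, and dilations and groups are omitted for brevity; this is a hedged illustration, not the decided implementation.
```c++
#include <vector>
#include <mkldnn.hpp>

// Hedged sketch: expand the descriptor creation from the attributes gathered above.
mkldnn::convolution_forward::desc MakeConvFwdDesc(
    const mkldnn::memory::desc& src_md, const mkldnn::memory::desc& wgt_md,
    const mkldnn::memory::desc& dst_md, const std::vector<int>& strides,
    const std::vector<int>& paddings, mkldnn::prop_kind pk,
    mkldnn::algorithm algo) {
  mkldnn::memory::dims stride_dims(strides.begin(), strides.end());
  mkldnn::memory::dims padding_dims(paddings.begin(), paddings.end());
  return mkldnn::convolution_forward::desc(
      pk, algo, src_md, wgt_md, dst_md, stride_dims,
      /*padding_l=*/padding_dims, /*padding_r=*/padding_dims,
      mkldnn::padding_kind::zero);
}
```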
### MKLDNNDeviceContext
`MKLDNNDeviceContext` is very straightforward; it should contain some basic members such as the `stream`, the `engine`, and the maps mentioned above.
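A possible shape of this class is sketched below. The member functions mirror the calls used in the `Compute` and `trans` sketches, but the names, the standalone form (in Paddle it would derive from `platform::DeviceContext` and take a `CPUPlace`), and the omitted primitive-descriptor map are assumptions rather than a settled interface.
```c++
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <mkldnn.hpp>

// Hedged sketch of MKLDNNDeviceContext: it owns the engine and an execution
// pipeline, and caches primitives/memories by the per-operator unique keys.
// A similar map for primitive descriptors (addPrimitiveDesc/findPrimitiveDesc)
// is omitted here for brevity.
class MKLDNNDeviceContext {
 public:
  MKLDNNDeviceContext() : engine_(mkldnn::engine::cpu, 0) {}

  const mkldnn::engine& getEngine() const { return engine_; }

  // Return nullptr if nothing has been recorded under this key yet.
  std::shared_ptr<mkldnn::primitive> findPrimitive(const std::string& key) const {
    auto it = primitives_.find(key);
    return it == primitives_.end() ? nullptr : it->second;
  }
  std::shared_ptr<mkldnn::memory> findMemory(const std::string& key) const {
    auto it = memories_.find(key);
    return it == memories_.end() ? nullptr : it->second;
  }
  void addPrimitive(const std::string& key, std::shared_ptr<mkldnn::primitive> p) {
    primitives_[key] = p;
  }
  void addMemory(const std::string& key, std::shared_ptr<mkldnn::memory> m) {
    memories_[key] = m;
  }

  // Queue a primitive, then run everything queued so far in one eager stream.
  void submit(std::shared_ptr<mkldnn::primitive> p) { pipeline_.push_back(*p); }
  void execute() {
    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline_).wait();
    pipeline_.clear();
  }

 private:
  mkldnn::engine engine_;
  std::vector<mkldnn::primitive> pipeline_;
  std::unordered_map<std::string, std::shared_ptr<mkldnn::primitive>> primitives_;
  std::unordered_map<std::string, std::shared_ptr<mkldnn::memory>> memories_;
};
```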
### mkldnn_helper
Some helper functions would be put in `paddle/platform/mkldnn_helper.h`, for example:
- create MKLDNN memories
- create MKLDNN primitives
- error checking functions
- etc.
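A few of these helpers might look like the sketch below; the function names (`MKLDNNMemDesc`, `MKLDNNMem`, `MKLDNNEnforce`) are placeholders chosen for illustration, not an agreed interface.
```c++
// paddle/platform/mkldnn_helper.h (a sketch; function names are assumptions)
#pragma once
#include <memory>
#include <vector>
#include <mkldnn.hpp>
#include "paddle/platform/enforce.h"  // PADDLE_ENFORCE

namespace paddle {
namespace platform {

// Build an MKL-DNN memory descriptor from tensor dims and a given format.
inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims,
                                          mkldnn::memory::data_type dtype,
                                          mkldnn::memory::format fmt) {
  mkldnn::memory::dims md_dims(dims.begin(), dims.end());
  return mkldnn::memory::desc(md_dims, dtype, fmt);
}

// Wrap an existing buffer into an MKL-DNN memory object on the given engine.
inline std::shared_ptr<mkldnn::memory> MKLDNNMem(const mkldnn::memory::desc& md,
                                                 const mkldnn::engine& eng,
                                                 void* data) {
  return std::make_shared<mkldnn::memory>(
      mkldnn::memory::primitive_desc(md, eng), data);
}

// Error checking around MKL-DNN C API status codes.
inline void MKLDNNEnforce(mkldnn_status_t status, const char* msg) {
  PADDLE_ENFORCE(status == mkldnn_success, msg);
}

}  // namespace platform
}  // namespace paddle
```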
### Kernel Switch
We should `reorder` data whose `Layout` differs when it comes from or goes to another device or kernel type. The `GetExpectedKernelType` and `trans` functions can help us implement this.
`GetExpectedKernelType` receives the context, so the operator can return the best `KernelType`.
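As a hedged illustration (the `use_mkldnn` attribute, the build macro name, and the exact `OpKernelType` constructor order are assumptions based on the kernel-type design doc, not decisions made here), overriding `GetExpectedKernelType` for a conv-like operator might look like this:
```c++
// Hedged sketch: choose the MKLDNN library for this kernel only when MKL-DNN
// is built in and an (assumed) "use_mkldnn" attribute is set; otherwise fall
// back to the plain CPU kernel with the default layout.
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  framework::LibraryType library = framework::LibraryType::kPlain;
#ifdef PADDLE_WITH_MKLDNN  // build macro name is an assumption
  if (ctx.Attr<bool>("use_mkldnn")) {
    library = framework::LibraryType::kMKLDNN;
  }
#endif
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
      ctx.GetPlace(), framework::DataLayout::kAnyLayout, library);
}
```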
`trans` would look like this:
```c++
void trans(inputs, ctx) override {
  if (NoNeedTrans()) {
    return;
  }
  // find the reorder primitive by op_key from the mkldnn device context
  auto& dev_ctx = ctx.template device_context<platform::MKLDNNDeviceContext>();
  auto p = dev_ctx.findPrimitive(op_key + "_reorder_input");
  auto i = dev_ctx.findMemory(op_key + "_src_input");
  if (p == nullptr || i == nullptr || sizeChanged(i, input)) {
    auto prim = createPrimitiveDesc(ctx);
    auto src = createMemory(memoryDesc(input->dims(), actual_layout), input->data);
    auto newbuffer = paddle::memory::Alloc(ctx.GetPlace(), input->size_in_bytes());
    auto dst = createMemory(prim->expected_desc(), newbuffer->data);
    shared_ptr<mkldnn::reorder> reorder_primitive(new mkldnn::reorder(src, dst));
    dev_ctx.addMemory(op_key + "_src_input", src);
    dev_ctx.addMemory(op_key + "_input", dst);
    dev_ctx.addPrimitive(op_key + "_reorder_input", reorder_primitive);
  }
  p = dev_ctx.findPrimitive(op_key + "_reorder_input");
  PADDLE_ENFORCE(p, "Should have Reorder Primitive");
  dev_ctx.submit(p);
  if (!this->isMKLDNNKernel()) {
    // execute immediately only if this is not an mkldnn kernel;
    // otherwise, it can be executed together with the operator primitive in Compute
    dev_ctx.execute();
  }
  // after submission, the input tensor in ExecutionContext should be replaced by the converted one;
  // there should be another mechanism to ensure this
}
```
### Unit Test
All the functions above should be covered by corresponding unit tests.
TBD
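As a starting point, the key-based caching in `MKLDNNDeviceContext` could be covered by a small gtest like the sketch below, which reuses the hypothetical class and helper names from the sketches above and is therefore only illustrative.
```c++
#include <string>
#include <vector>
#include <mkldnn.hpp>
#include "gtest/gtest.h"
// Assumes the MKLDNNDeviceContext and mkldnn_helper sketches above; the
// helper functions live in paddle::platform there.
using paddle::platform::MKLDNNMem;
using paddle::platform::MKLDNNMemDesc;

TEST(MKLDNNDeviceContext, CachesMemoriesAndPrimitivesByKey) {
  MKLDNNDeviceContext dev_ctx;
  const std::string op_key = "conv2d-conv2d_0.tmp_0";  // hypothetical unique key
  // Nothing is cached yet, so the Compute sketch would take the creation branch.
  EXPECT_EQ(dev_ctx.findPrimitive(op_key + "_fwd"), nullptr);
  EXPECT_EQ(dev_ctx.findMemory(op_key + "_input"), nullptr);

  // Cache one memory object and look it up again by its key.
  std::vector<float> buf(1 * 1 * 4 * 4, 0.f);
  auto md = MKLDNNMemDesc({1, 1, 4, 4}, mkldnn::memory::data_type::f32,
                          mkldnn::memory::format::nchw);
  dev_ctx.addMemory(op_key + "_input",
                    MKLDNNMem(md, dev_ctx.getEngine(), buf.data()));
  EXPECT_NE(dev_ctx.findMemory(op_key + "_input"), nullptr);
}
```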
<li class="toctree-l2"><a class="reference internal" href="../../howto/dev/write_docs_cn.html">如何贡献/修改文档</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../howto/deep_model/rnn/index_cn.html">RNN相关模型</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../howto/deep_model/rnn/rnn_config_cn.html">RNN配置</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../howto/deep_model/rnn/recurrent_group_cn.html">Recurrent Group教程</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../howto/deep_model/rnn/hierarchical_layer_cn.html">支持双层序列作为输入的Layer</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../howto/deep_model/rnn/hrnn_rnn_api_compare_cn.html">单双层RNN API对比介绍</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../howto/optimization/gpu_profiling_cn.html">GPU性能分析与调优</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../../api/index_cn.html">API</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../../api/v2/model_configs.html">模型配置</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/config/activation.html">Activation</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/config/layer.html">Layers</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/config/evaluators.html">Evaluators</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/config/optimizer.html">Optimizer</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/config/pooling.html">Pooling</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/config/networks.html">Networks</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/config/attr.html">Parameter Attribute</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../api/v2/data.html">数据访问</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/data/data_reader.html">Data Reader Interface</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/data/image.html">Image Interface</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/data/dataset.html">Dataset</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../api/v2/run_logic.html">训练与应用</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../api/v2/fluid.html">Fluid</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/fluid/layers.html">Layers</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/fluid/data_feeder.html">DataFeeder</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/fluid/executor.html">Executor</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/fluid/initializer.html">Initializer</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/fluid/evaluator.html">Evaluator</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/fluid/nets.html">Nets</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/fluid/optimizer.html">Optimizer</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/fluid/param_attr.html">ParamAttr</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/fluid/profiler.html">Profiler</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../api/v2/fluid/regularizer.html">Regularizer</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../../faq/index_cn.html">FAQ</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../../faq/build_and_install/index_cn.html">编译安装与单元测试</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../faq/model/index_cn.html">模型配置</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../faq/parameter/index_cn.html">参数设置</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../faq/local/index_cn.html">本地训练与预测</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../faq/cluster/index_cn.html">集群训练与预测</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../../mobile/index_cn.html">MOBILE</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../../mobile/cross_compiling_for_android_cn.html">Android平台编译指南</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../mobile/cross_compiling_for_ios_cn.html">iOS平台编译指南</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../mobile/cross_compiling_for_raspberry_cn.html">Raspberry Pi平台编译指南</a></li>
</ul>
</li>
</ul>
</nav>
<section class="doc-content-wrap">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="wy-breadcrumbs">
<li>Design Doc: Add MKLDNN Kernel in Fluid Operator</li>
</ul>
</div>
<div class="wy-nav-content" id="doc-content">
<div class="rst-content">
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
<div itemprop="articleBody">
<div class="section" id="design-doc-add-mkldnn-kernel-in-fluid-operator">
<span id="design-doc-add-mkldnn-kernel-in-fluid-operator"></span><h1>Design Doc: Add MKLDNN Kernel in Fluid Operator<a class="headerlink" href="#design-doc-add-mkldnn-kernel-in-fluid-operator" title="永久链接至标题"></a></h1>
<div class="section" id="principles">
<span id="principles"></span><h2>Principles<a class="headerlink" href="#principles" title="永久链接至标题"></a></h2>
<p>First of all, we should follow some basical principles like:</p>
<ol class="simple">
<li><a class="reference external" href="https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_en.md">How to write a new operator</a>. We are trying to add a new kind of kernel into operators, so basically we should follow this doc.</li>
<li><a class="reference external" href="https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/support_new_device.md">Supporting new Device/Library</a>. Since MKLDNN is a new library to fluid, we should add <code class="docutils literal"><span class="pre">MKLDNNDeviceContext</span></code> and maybe <code class="docutils literal"><span class="pre">mkldnn_helper.h</span></code>, just like <a class="reference external" href="https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/cudnn_helper.h">cudnn_helper.h</a>.</li>
<li><a class="reference external" href="https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md">Switch Kernel</a>. Another important point is that we should ensure the data synchronization between different kernel types, which is this <a class="reference external" href="https://github.com/PaddlePaddle/Paddle/issues/6549">topic</a>. So basically we should override <code class="docutils literal"><span class="pre">GetExpectedKernelType</span></code> and <code class="docutils literal"><span class="pre">trans</span></code> functions to support switching kernels.</li>
<li><a class="reference external" href="https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md">The Keys of Operator Kernel Type</a>. Kernel Type is a pivotal conception which can record the <code class="docutils literal"><span class="pre">Place</span></code>, <code class="docutils literal"><span class="pre">Library</span></code>, <code class="docutils literal"><span class="pre">DataType</span></code> and <code class="docutils literal"><span class="pre">Layout</span></code>.</li>
</ol>
</div>
<div class="section" id="sulution">
<span id="sulution"></span><h2>Sulution<a class="headerlink" href="#sulution" title="永久链接至标题"></a></h2>
<p>In general, there are four parts we should follow to run a MKL-DNN primitive.</p>
<ul class="simple">
<li>Create a primitive descriptor that describe this operator</li>
<li>Create a primitive itself by primitive descriptor and the engine</li>
<li>Create all memory buffers that primitive needed</li>
<li>Launch a stream to execute the primitive created
More details can refer to <a class="reference external" href="http://01org.github.io/mkl-dnn">here</a>.</li>
</ul>
<p>It&#8217;s better to avoid reinitialization of primitives and memory handles in the first three stages in every iteration. So we plan to create a map to record all the <code class="docutils literal"><span class="pre">primitive</span></code> and <code class="docutils literal"><span class="pre">memory</span></code>, which should not take too much memories as discussed <a class="reference external" href="https://github.com/PaddlePaddle/Paddle/issues/6822">here</a>.</p>
<p>It&#8217;s assumed that following three conditions should be satisfied.</p>
<ol class="simple">
<li>there is a unique key for each operator instance. May be the actual name of <code class="docutils literal"><span class="pre">Output</span> <span class="pre">Tensor</span></code>.</li>
<li>the <code class="docutils literal"><span class="pre">Input</span> <span class="pre">Tensor</span></code> inside <code class="docutils literal"><span class="pre">Compute</span></code> function is the one after converted.</li>
<li>we can get the phase(eg. <code class="docutils literal"><span class="pre">is_test</span></code>) inside <code class="docutils literal"><span class="pre">Compute</span></code> function, otherwise we need to expose this attribue to user.</li>
</ol>
<div class="section" id="compute">
<span id="compute"></span><h3>Compute<a class="headerlink" href="#compute" title="永久链接至标题"></a></h3>
<p>The algorithm of <code class="docutils literal"><span class="pre">Compute</span></code> would be described as follow, let&#8217;s take conv like an example.</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="n">PADDLE_ENFORCE</span><span class="p">(</span><span class="n">platform</span><span class="o">::</span><span class="n">is_cpu_place</span><span class="p">(</span><span class="n">ctx</span><span class="p">.</span><span class="n">GetPlace</span><span class="p">()),</span> <span class="s">&quot;It must use CPUPlace.&quot;</span><span class="p">);</span>
<span class="n">PADDLE_ENFORCE</span><span class="p">(</span><span class="n">platform</span><span class="o">::</span><span class="n">is_mkldnn_library</span><span class="p">(</span><span class="n">ctx</span><span class="p">.</span><span class="n">GetLibrary</span><span class="p">()),</span> <span class="s">&quot;It must use MKLDNN Library.&quot;</span><span class="p">);</span>
<span class="k">auto</span><span class="o">&amp;</span> <span class="n">dev_ctx</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="k">template</span> <span class="n">device_context</span><span class="o">&lt;</span><span class="n">platform</span><span class="o">::</span><span class="n">MKLDNNDeviceContext</span><span class="o">&gt;</span><span class="p">();</span>
<span class="c1">// find primitive by unique key from mkldnn context</span>
<span class="c1">// the op_key should be a unique name of this op instance</span>
<span class="k">auto</span><span class="o">&amp;</span> <span class="n">p</span> <span class="o">=</span> <span class="n">dev_ctx</span><span class="p">.</span><span class="n">findPrimitive</span><span class="p">(</span><span class="n">op_key</span> <span class="o">+</span> <span class="s">&quot;_fwd&quot;</span><span class="p">);</span>
<span class="c1">// assuming the input tensor inside this compute function is the one after converted</span>
<span class="c1">// this point should be guarantee by another mechanism</span>
<span class="k">auto</span><span class="o">&amp;</span> <span class="n">i</span> <span class="o">=</span> <span class="n">dev_ctx</span><span class="p">.</span><span class="n">findMemory</span><span class="p">(</span><span class="n">op_key</span> <span class="o">+</span> <span class="s">&quot;_input&quot;</span><span class="p">);</span>
<span class="k">if</span> <span class="p">(</span><span class="n">p</span> <span class="o">==</span> <span class="k">nullptr</span> <span class="o">||</span> <span class="n">i</span> <span class="o">==</span> <span class="k">nullptr</span> <span class="o">||</span> <span class="n">inputSizeChanged</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span> <span class="p">{</span>
<span class="k">auto</span> <span class="n">fwd_primitive_desc</span> <span class="o">=</span> <span class="n">createPrimitiveDesc</span><span class="p">(</span><span class="n">ctx</span><span class="p">);</span>
<span class="k">auto</span><span class="o">*</span> <span class="n">input</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">Input</span><span class="o">&lt;</span><span class="n">Tensor</span><span class="o">&gt;</span><span class="p">(</span><span class="s">&quot;Input&quot;</span><span class="p">);</span>
<span class="k">auto</span><span class="o">*</span> <span class="n">filter</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">Input</span><span class="o">&lt;</span><span class="n">Tensor</span><span class="o">&gt;</span><span class="p">(</span><span class="s">&quot;Filter&quot;</span><span class="p">);</span>
<span class="k">auto</span><span class="o">*</span> <span class="n">output</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">Output</span><span class="o">&lt;</span><span class="n">Tensor</span><span class="o">&gt;</span><span class="p">(</span><span class="s">&quot;Output&quot;</span><span class="p">);</span>
<span class="n">shared_ptr</span><span class="o">&lt;</span><span class="n">mkldnn</span><span class="o">::</span><span class="n">memory</span><span class="o">&gt;</span> <span class="n">in</span><span class="p">(</span><span class="k">new</span> <span class="n">mkldnn</span><span class="o">::</span><span class="n">memory</span><span class="p">(</span><span class="n">fwd_primitive_desc</span><span class="o">-&gt;</span><span class="n">src_primitive_desc</span><span class="p">(),</span> <span class="n">input</span><span class="o">-&gt;</span><span class="n">data</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;</span><span class="p">()));</span>
<span class="n">shared_ptr</span><span class="o">&lt;</span><span class="n">mkldnn</span><span class="o">::</span><span class="n">memory</span><span class="o">&gt;</span> <span class="n">wgt</span><span class="p">(</span><span class="k">new</span> <span class="n">mkldnn</span><span class="o">::</span><span class="n">memory</span><span class="p">(</span><span class="n">fwd_primitive_desc</span><span class="o">-&gt;</span><span class="n">weights_primitive_desc</span><span class="p">(),</span> <span class="n">filter</span><span class="o">-&gt;</span><span class="n">data</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;</span><span class="p">()));</span>
<span class="n">shared_ptr</span><span class="o">&lt;</span><span class="n">mkldnn</span><span class="o">::</span><span class="n">memory</span><span class="o">&gt;</span> <span class="n">out</span><span class="p">(</span><span class="k">new</span> <span class="n">mkldnn</span><span class="o">::</span><span class="n">memory</span><span class="p">(</span><span class="n">fwd_primitive_desc</span><span class="o">-&gt;</span><span class="n">dst_primitive_desc</span><span class="p">(),</span> <span class="n">output</span><span class="o">-&gt;</span><span class="n">mutable_data</span><span class="o">&lt;</span><span class="n">T</span><span class="o">&gt;</span><span class="p">(</span><span class="n">ctx</span><span class="p">.</span><span class="n">GetPlace</span><span class="p">())));</span>
<span class="n">shared_ptr</span><span class="o">&lt;</span><span class="n">mkldnn</span><span class="o">::</span><span class="n">conv_fwd</span><span class="o">&gt;</span> <span class="n">fwd_primitive</span><span class="p">(</span><span class="k">new</span> <span class="n">mkldnn</span><span class="o">::</span><span class="n">conv_fwd</span><span class="p">(</span><span class="o">*</span><span class="n">fwd_primitive_desc</span><span class="p">,</span> <span class="o">*</span><span class="n">in</span><span class="p">,</span> <span class="o">*</span><span class="n">wgt</span><span class="p">,</span> <span class="o">*</span><span class="n">out</span><span class="p">));</span>
<span class="n">dev_ctx</span><span class="p">.</span><span class="n">addMemory</span><span class="p">(</span><span class="n">op_key</span><span class="o">+</span><span class="s">&quot;_input&quot;</span><span class="p">,</span> <span class="n">in</span><span class="p">);</span>
<span class="n">dev_ctx</span><span class="p">.</span><span class="n">addMemory</span><span class="p">(</span><span class="n">op_key</span><span class="o">+</span><span class="s">&quot;_output&quot;</span><span class="p">,</span> <span class="n">out</span><span class="p">);</span>
<span class="n">dev_ctx</span><span class="p">.</span><span class="n">addMemory</span><span class="p">(</span><span class="n">op_key</span><span class="o">+</span><span class="s">&quot;_filer&quot;</span><span class="p">,</span> <span class="n">wgt</span><span class="p">);</span>
<span class="n">dev_ctx</span><span class="p">.</span><span class="n">addPrimitive</span><span class="p">(</span><span class="n">op_key</span><span class="o">+</span><span class="s">&quot;_fwd&quot;</span><span class="p">,</span> <span class="n">fwd_primitive</span><span class="p">);</span>
<span class="n">dev_ctx</span><span class="p">.</span><span class="n">addPrimitiveDesc</span><span class="p">(</span><span class="n">op_key</span><span class="o">+</span><span class="s">&quot;_fwd_PD&quot;</span><span class="p">,</span> <span class="n">fwd_primitive_desc</span><span class="p">);</span>
<span class="p">}</span>
<span class="n">p</span> <span class="o">=</span> <span class="n">dev_ctx</span><span class="p">.</span><span class="n">findPrimitive</span><span class="p">(</span><span class="n">op_key</span> <span class="o">+</span> <span class="s">&quot;_fwd&quot;</span><span class="p">);</span>
<span class="n">PADDLE_ENFORCE</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="s">&quot;Should have forward Primitive&quot;</span><span class="p">);</span>
<span class="n">PADDLE_ENFORCE</span><span class="p">(</span><span class="n">dev_ctx</span><span class="p">.</span><span class="n">findMemory</span><span class="p">(</span><span class="n">op_unique_key</span><span class="o">+</span><span class="s">&quot;_input&quot;</span><span class="p">),</span> <span class="s">&quot;Should have input memory&quot;</span><span class="p">);</span>
<span class="n">PADDLE_ENFORCE</span><span class="p">(</span><span class="n">dev_ctx</span><span class="p">.</span><span class="n">findMemory</span><span class="p">(</span><span class="n">op_unique_key</span><span class="o">+</span><span class="s">&quot;_output&quot;</span><span class="p">),</span> <span class="s">&quot;Should have output memory&quot;</span><span class="p">);</span>
<span class="n">PADDLE_ENFORCE</span><span class="p">(</span><span class="n">dev_ctx</span><span class="p">.</span><span class="n">findMemory</span><span class="p">(</span><span class="n">op_unique_key</span><span class="o">+</span><span class="s">&quot;_filter&quot;</span><span class="p">),</span> <span class="s">&quot;Should have filter memory&quot;</span><span class="p">);</span>
<span class="n">PADDLE_ENFORCE</span><span class="p">(</span><span class="n">dev_ctx</span><span class="p">.</span><span class="n">findPrimitiveDesc</span><span class="p">(</span><span class="n">op_unique_key</span><span class="o">+</span><span class="s">&quot;_fwd_PD&quot;</span><span class="p">),</span> <span class="s">&quot;Should have forward PrimitiveDesc&quot;</span><span class="p">);</span>
<span class="n">dev_ctx</span><span class="p">.</span><span class="n">submit</span><span class="p">(</span><span class="n">p</span><span class="p">);</span>
<span class="n">dev_ctx</span><span class="p">.</span><span class="n">execute</span><span class="p">();</span> <span class="c1">// the convert primitive should have already contained.</span>
</pre></div>
</div>
<p>The <code class="docutils literal"><span class="pre">createPrimitiveDesc</span></code> returns the primitive descripotor of this operator, would be like this:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span> <span class="k">auto</span><span class="o">*</span> <span class="n">input</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">Input</span><span class="o">&lt;</span><span class="n">Tensor</span><span class="o">&gt;</span><span class="p">(</span><span class="s">&quot;Input&quot;</span><span class="p">);</span>
<span class="k">auto</span><span class="o">*</span> <span class="n">filter</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">Input</span><span class="o">&lt;</span><span class="n">Tensor</span><span class="o">&gt;</span><span class="p">(</span><span class="s">&quot;Filter&quot;</span><span class="p">);</span>
<span class="k">auto</span><span class="o">*</span> <span class="n">output</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">Output</span><span class="o">&lt;</span><span class="n">Tensor</span><span class="o">&gt;</span><span class="p">(</span><span class="s">&quot;Output&quot;</span><span class="p">);</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="n">strides</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">Attr</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&gt;</span><span class="p">(</span><span class="s">&quot;strides&quot;</span><span class="p">);</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="n">paddings</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">Attr</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&gt;</span><span class="p">(</span><span class="s">&quot;paddings&quot;</span><span class="p">);</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span> <span class="n">dilations</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">Attr</span><span class="o">&lt;</span><span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;&gt;</span><span class="p">(</span><span class="s">&quot;dilations&quot;</span><span class="p">);</span>
<span class="kt">int</span> <span class="n">groups</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">Attr</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span><span class="p">(</span><span class="s">&quot;groups&quot;</span><span class="p">);</span>
<span class="n">algorithm</span> <span class="n">algo</span> <span class="o">=</span> <span class="k">static_cast</span><span class="o">&lt;</span><span class="n">algorithm</span><span class="o">&gt;</span><span class="p">(</span><span class="n">ctx</span><span class="p">.</span><span class="n">Attr</span><span class="o">&lt;</span><span class="kt">int</span><span class="o">&gt;</span><span class="p">(</span><span class="s">&quot;convolution_algorithm_option&quot;</span><span class="p">));</span>
<span class="n">prop_kind</span> <span class="n">pk</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="n">Attr</span><span class="o">&lt;</span><span class="kt">bool</span><span class="o">&gt;</span><span class="p">(</span><span class="s">&quot;is_test&quot;</span><span class="p">)</span> <span class="o">?</span> <span class="n">prop_kind</span><span class="o">::</span><span class="nl">forward_inference</span> <span class="p">:</span> <span class="n">prop_kind</span><span class="o">::</span><span class="n">forward_training</span><span class="p">;</span>
<span class="k">auto</span> <span class="n">fwd_desc</span> <span class="o">=</span> <span class="n">mkldnn</span><span class="o">::</span><span class="n">conv_fwd</span><span class="o">::</span><span class="n">desc</span><span class="p">(</span><span class="cm">/* all the setting above*/</span><span class="p">);</span>
<span class="n">shared_ptr</span><span class="o">&lt;</span><span class="n">mkldnn</span><span class="o">::</span><span class="n">conv_fwd</span><span class="o">::</span><span class="n">primitive_desc</span><span class="o">&gt;</span> <span class="n">fwd_primitive_desc</span><span class="p">(</span><span class="k">new</span> <span class="n">mkldnn</span><span class="o">::</span><span class="n">conv_fwd</span><span class="o">::</span><span class="n">primitive_desc</span><span class="p">(</span><span class="n">fwd_desc</span><span class="p">,</span> <span class="n">ctx</span><span class="p">.</span><span class="n">getEngine</span><span class="p">()));</span>
<span class="k">return</span> <span class="n">fwd_primitive_desc</span><span class="p">;</span>
<span class="p">}</span>
</pre></div>
</div>
</div>
<div class="section" id="mkldnndevicecontext">
<span id="mkldnndevicecontext"></span><h3>MKLDNNDeviceContext<a class="headerlink" href="#mkldnndevicecontext" title="永久链接至标题"></a></h3>
<p><code class="docutils literal"><span class="pre">MKLDNNDeviceContext</span></code>, which is very straightforward, should contain some base information like: <code class="docutils literal"><span class="pre">stream</span></code>, <code class="docutils literal"><span class="pre">engine</span></code> and the map needed.</p>
</div>
<div class="section" id="mkldnn-helper">
<span id="mkldnn-helper"></span><h3>mkldnn_helper<a class="headerlink" href="#mkldnn-helper" title="永久链接至标题"></a></h3>
<p>Some functions would be put in <code class="docutils literal"><span class="pre">paddle/platform/mkldnn_helper.h</span></code>.</p>
<ul class="simple">
<li>create MKLDNN memories</li>
<li>create MKLDNN primitives</li>
<li>error check function</li>
<li>etc</li>
</ul>
</div>
<div class="section" id="kernel-switch">
<span id="kernel-switch"></span><h3>Kernel Switch<a class="headerlink" href="#kernel-switch" title="永久链接至标题"></a></h3>
<p>We should <code class="docutils literal"><span class="pre">reorder</span></code> the different Layout from other device or to other device. <code class="docutils literal"><span class="pre">GetExpectedKernelType</span></code> and <code class="docutils literal"><span class="pre">trans</span></code> functions can help us to implement it.</p>
<p><code class="docutils literal"><span class="pre">GetExpectedKernelType</span></code> should get the context, and this operator can return the best <code class="docutils literal"><span class="pre">KernelType</span></code>.
<code class="docutils literal"><span class="pre">trans</span></code> would be like this:</p>
<div class="highlight-c++"><div class="highlight"><pre><span></span><span class="kt">void</span> <span class="nf">trans</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">ctx</span><span class="p">)</span> <span class="k">override</span> <span class="p">{</span>
<span class="k">if</span> <span class="p">(</span><span class="n">NoNeedTrans</span><span class="p">())</span> <span class="p">{</span>
<span class="k">return</span><span class="p">;</span>
<span class="p">}</span>
<span class="c1">// find reorder primitive by op_key from context</span>
<span class="k">auto</span><span class="o">&amp;</span> <span class="n">dev_ctx</span> <span class="o">=</span> <span class="n">ctx</span><span class="p">.</span><span class="k">template</span> <span class="n">device_context</span><span class="o">&lt;</span><span class="n">platform</span><span class="o">::</span><span class="n">MKLDNNDeviceContext</span><span class="o">&gt;</span><span class="p">();</span>
<span class="k">auto</span><span class="o">&amp;</span> <span class="n">p</span> <span class="o">=</span> <span class="n">dev_ctx</span><span class="p">.</span><span class="n">findPrimitive</span><span class="p">(</span><span class="n">op_key</span> <span class="o">+</span> <span class="s">&quot;_reorder_input&quot;</span><span class="p">);</span>
<span class="k">auto</span><span class="o">&amp;</span> <span class="n">i</span> <span class="o">=</span> <span class="n">dev_ctx</span><span class="p">.</span><span class="n">findMemory</span><span class="p">(</span><span class="n">op_key</span> <span class="o">+</span> <span class="s">&quot;_src_input&quot;</span><span class="p">);</span>
<span class="k">if</span> <span class="p">(</span><span class="n">p</span> <span class="o">==</span> <span class="k">nullptr</span> <span class="o">||</span> <span class="n">i</span> <span class="o">==</span> <span class="k">nullptr</span> <span class="o">||</span> <span class="n">changeSized</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">input</span><span class="p">))</span> <span class="p">{</span>
<span class="k">auto</span> <span class="n">prim</span> <span class="o">=</span> <span class="n">createPrimitiveDesc</span><span class="p">(</span><span class="n">ctx</span><span class="p">);</span>
<span class="k">auto</span> <span class="n">src</span> <span class="o">=</span> <span class="n">createMemory</span><span class="p">(</span><span class="n">memoryDesc</span><span class="p">(</span><span class="n">input</span><span class="o">-&gt;</span><span class="n">dims</span><span class="p">(),</span> <span class="n">actual_layout</span><span class="p">),</span> <span class="n">input</span><span class="o">-&gt;</span><span class="n">data</span><span class="p">);</span>
<span class="k">auto</span> <span class="n">newbuffer</span> <span class="o">=</span> <span class="n">paddle</span><span class="o">::</span><span class="n">memory</span><span class="o">::</span><span class="n">Alloc</span><span class="p">(</span><span class="n">ctx</span><span class="p">.</span><span class="n">GetPlace</span><span class="p">(),</span> <span class="n">input</span><span class="o">-&gt;</span><span class="n">size_in_bytes</span><span class="p">());</span>
<span class="k">auto</span> <span class="n">dst</span> <span class="o">=</span> <span class="n">createMemory</span><span class="p">(</span><span class="n">p</span><span class="o">-&gt;</span><span class="n">expected_desc</span><span class="p">(),</span> <span class="n">newbuffer</span><span class="o">-&gt;</span><span class="n">data</span><span class="p">);</span>
<span class="k">auto</span> <span class="n">reorder_primitive</span><span class="p">(</span><span class="k">new</span> <span class="n">mkldnn</span><span class="o">::</span><span class="n">reorder</span><span class="p">(</span><span class="n">src</span><span class="p">,</span> <span class="n">dst</span><span class="p">));</span>
<span class="n">dev_ctx</span><span class="p">.</span><span class="n">addMemory</span><span class="p">(</span><span class="n">op_key</span><span class="o">+</span><span class="s">&quot;_src_input&quot;</span><span class="p">,</span> <span class="n">src</span><span class="p">);</span>
<span class="n">dev_ctx</span><span class="p">.</span><span class="n">addMemory</span><span class="p">(</span><span class="n">op_key</span><span class="o">+</span><span class="s">&quot;_input&quot;</span><span class="p">,</span> <span class="n">dst</span><span class="p">);</span>
<span class="n">dev_ctx</span><span class="p">.</span><span class="n">addPrimitive</span><span class="p">(</span><span class="n">op_key</span><span class="o">+</span><span class="s">&quot;_reorder_input&quot;</span><span class="p">,</span> <span class="n">reorder_primitive</span><span class="p">);</span>
<span class="p">}</span>
<span class="n">p</span> <span class="o">=</span> <span class="n">dev_ctx</span><span class="p">.</span><span class="n">findPrimitive</span><span class="p">(</span><span class="n">op_key</span> <span class="o">+</span> <span class="s">&quot;_reorder_input&quot;</span><span class="p">);</span>
<span class="n">PADDLE_ENFORCE</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="s">&quot;Should have Reorder Primitive&quot;</span><span class="p">);</span>
<span class="n">dev_ctx</span><span class="p">.</span><span class="n">submit</span><span class="p">(</span><span class="n">p</span><span class="p">);</span>
<span class="k">if</span> <span class="p">(</span><span class="o">!</span> <span class="k">this</span><span class="o">-&gt;</span><span class="n">isMKLDNNKernel</span><span class="p">())</span> <span class="p">{</span>
<span class="c1">// execute immediately only if this is not mkldnn kernel function.</span>
<span class="c1">// otherwise, it can be executed with the operator primitive in Compute</span>
<span class="n">dev_ctx</span><span class="p">.</span><span class="n">stream</span><span class="p">();</span>
<span class="p">}</span>
<span class="c1">// after submit, the input tensor in ExecutionContext should be changed as the converted one</span>
<span class="c1">// there should be another mechanism to ensure this</span>
<span class="p">}</span>
</pre></div>
</div>
</div>
<div class="section" id="unit-test">
<span id="unit-test"></span><h3>Unit Test<a class="headerlink" href="#unit-test" title="永久链接至标题"></a></h3>
<p>All the functions should be tested corresponding.
TBD</p>
</div>
</div>
</div>
</div>
</div>
<footer>
<hr/>
<div role="contentinfo">
<p>
&copy; Copyright 2016, PaddlePaddle developers.
</p>
</div>
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</footer>
</div>
</div>
</section>
</div>
<script type="text/javascript">
var DOCUMENTATION_OPTIONS = {
URL_ROOT:'../../',
VERSION:'',
COLLAPSE_INDEX:false,
FILE_SUFFIX:'.html',
HAS_SOURCE: true,
SOURCELINK_SUFFIX: ".txt",
};
</script>
<script type="text/javascript" src="../../_static/jquery.js"></script>
<script type="text/javascript" src="../../_static/underscore.js"></script>
<script type="text/javascript" src="../../_static/doctools.js"></script>
<script type="text/javascript" src="../../_static/translations.js"></script>
<script type="text/javascript" src="https://cdn.bootcss.com/mathjax/2.7.0/MathJax.js"></script>
<script type="text/javascript" src="../../_static/js/theme.js"></script>
<script src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/js/bootstrap.min.js" integrity="sha384-Tc5IQib027qvyjSMfHjOMaLkfuWVxZxUPnCJA7l2mCWNIpG9mGCD8wGNIcPD7Txa" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/perfect-scrollbar/0.6.14/js/perfect-scrollbar.jquery.min.js"></script>
<script src="../../_static/js/paddle_doc_init.js"></script>
</body>
</html>
\ No newline at end of file
因为 它太大了无法显示 source diff 。你可以改为 查看blob
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册