提交 1415c0b0 编写于 作者: L Luo Tao

refine mnist html

上级 5d64fa7d
......@@ -46,10 +46,137 @@
<p>实际上,tanh函数只是规模变化的sigmoid函数,将sigmoid函数值放大2倍之后再向下平移1个单位:tanh(x) = 2sigmoid(2x) - 1 。</p></li>
<li><p>ReLU激活函数: <span class="MathJax_Preview"></span><span class="MathJax_SVG" id="MathJax-Element-44-Frame" role="textbox" aria-readonly="true" style="font-size: 100%; display: inline-block;"><svg xmlns:xlink="http://www.w3.org/1999/xlink" viewBox="0 -771.0516853480245 7513.722222222223 1042.103370696049" style="width: 17.49ex; height: 2.432ex; vertical-align: -0.695ex; margin: 1px 0px;"><g stroke="black" fill="black" stroke-width="0" transform="matrix(1 0 0 -1 0 0)"><use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#MJMATHI-66"></use><use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#MJMAIN-28" x="550" y="0"></use><use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#MJMATHI-78" x="940" y="0"></use><use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#MJMAIN-29" x="1512" y="0"></use><use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#MJMAIN-3D" x="2179" y="0"></use><use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#MJMATHI-6D" x="3236" y="0"></use><use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#MJMATHI-61" x="4114" y="0"></use><use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#MJMATHI-78" x="4644" y="0"></use><use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#MJMAIN-28" x="5216" y="0"></use><use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#MJMAIN-30" x="5606" y="0"></use><use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#MJMAIN-2C" x="6106" y="0"></use><use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#MJMATHI-78" x="6551" y="0"></use><use xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="#MJMAIN-29" x="7124" y="0"></use></g></svg></span><script type="math/tex" id="MathJax-Element-44"> f(x) = max(0, x) </script></p></li>
</ul><p data-anchor-id="tss0">更详细的介绍请参考<a href="https://en.wikipedia.org/wiki/Activation_function" target="_blank">维基百科激活函数</a></p><div class="md-section-divider"></div><h2 data-anchor-id="xbcu" id="数据准备">数据准备</h2><div class="md-section-divider"></div><h3 data-anchor-id="c26s" id="数据介绍与下载">数据介绍与下载</h3><p data-anchor-id="pim6">执行以下命令,下载<a href="http://yann.lecun.com/exdb/mnist/" target="_blank">MNIST</a>数据库并解压缩,然后将训练集和测试集的地址分别写入train.list和test.list两个文件,供PaddlePaddle读取。</p><div class="md-section-divider"></div><pre class="prettyprint linenums prettyprinted" data-anchor-id="1w26"><ol class="linenums"><li class="L0"><code class="language-bash"><span class="pun">./</span><span class="pln">data</span><span class="pun">/</span><span class="pln">get_mnist_data</span><span class="pun">.</span><span class="pln">sh</span></code></li></ol></pre><div class="md-section-divider"></div><pre class="prettyprint linenums prettyprinted" data-anchor-id="tfsi"><ol class="linenums"><li class="L0"><code class="language-bash"><span class="pun">./</span><span class="pln">load_data</span><span class="pun">.</span><span class="pln">py</span></code></li></ol></pre><div class="md-section-divider"></div><div class="md-section-divider"></div><pre class="prettyprint linenums prettyprinted" data-anchor-id="pysa"><ol class="linenums"><li class="L0"><code class="language-python"><span class="com"># Define a py data provider</span></code></li><li class="L1"><code class="language-python"><span class="lit">@provider</span><span class="pun">(</span></code></li><li class="L2"><code class="language-python"><span class="pln"> input_types</span><span class="pun">={</span><span class="str">'pixel'</span><span class="pun">:</span><span class="pln"> dense_vector</span><span class="pun">(</span><span class="lit">28</span><span class="pln"> </span><span class="pun">*</span><span class="pln"> </span><span class="lit">28</span><span class="pun">),</span></code></li><li class="L3"><code class="language-python"><span class="pln"> </span><span class="str">'label'</span><span class="pun">:</span><span class="pln"> integer_value</span><span class="pun">(</span><span class="lit">10</span><span class="pun">)})</span></code></li><li class="L4"><code class="language-python"><span class="kwd">def</span><span class="pln"> process</span><span class="pun">(</span><span class="pln">settings</span><span class="pun">,</span><span class="pln"> filename</span><span class="pun">):</span><span class="pln"> </span><span class="com"># settings is not used currently.</span></code></li><li class="L5"><code class="language-python"><span class="pln"> </span><span class="com"># 打开图片文件</span></code></li><li class="L6"><code class="language-python"><span class="pln"> </span><span class="kwd">with</span><span class="pln"> open</span><span class="pun">(</span><span class="pln"> filename </span><span class="pun">+</span><span class="pln"> </span><span class="str">"-images-idx3-ubyte"</span><span class="pun">,</span><span class="pln"> </span><span class="str">"rb"</span><span class="pun">)</span><span class="pln"> </span><span class="kwd">as</span><span class="pln"> f</span><span class="pun">:</span><span class="pln"> </span></code></li><li class="L7"><code class="language-python"><span class="pln"> </span><span class="com"># 读取开头的四个参数,magic代表数据的格式,n代表数据的总量,rows和cols分别代表行数和列数</span></code></li><li class="L8"><code class="language-python"><span class="pln"> magic</span><span class="pun">,</span><span class="pln"> n</span><span class="pun">,</span><span class="pln"> rows</span><span class="pun">,</span><span class="pln"> cols </span><span class="pun">=</span><span class="pln"> struct</span><span class="pun">.</span><span class="pln">upack</span><span class="pun">(</span><span class="str">"&gt;IIII"</span><span class="pun">,</span><span class="pln"> f</span><span class="pun">.</span><span class="pln">read</span><span class="pun">(</span><span class="lit">16</span><span class="pun">))</span><span class="pln"> </span></code></li><li class="L9"><code class="language-python"><span class="pln"> </span><span class="com"># 以无符号字节为单位一个一个的读取数据</span></code></li><li class="L0"><code class="language-python"><span class="pln"> images </span><span class="pun">=</span><span class="pln"> np</span><span class="pun">.</span><span class="pln">fromfile</span><span class="pun">(</span><span class="pln"> </span></code></li><li class="L1"><code class="language-python"><span class="pln"> f</span><span class="pun">,</span><span class="pln"> </span><span class="str">'ubyte'</span><span class="pun">,</span></code></li><li class="L2"><code class="language-python"><span class="pln"> count</span><span class="pun">=</span><span class="pln">n </span><span class="pun">*</span><span class="pln"> rows </span><span class="pun">*</span><span class="pln"> cols</span><span class="pun">).</span><span class="pln">reshape</span><span class="pun">(</span><span class="pln">n</span><span class="pun">,</span><span class="pln"> rows</span><span class="pun">,</span><span class="pln"> cols</span><span class="pun">).</span><span class="pln">astype</span><span class="pun">(</span><span class="str">'float32'</span><span class="pun">)</span></code></li><li class="L3"><code class="language-python"><span class="pln"> </span><span class="com"># 将0~255的数据归一化到[-1,1]的区间</span></code></li><li class="L4"><code class="language-python"><span class="pln"> images </span><span class="pun">=</span><span class="pln"> images </span><span class="pun">/</span><span class="pln"> </span><span class="lit">255.0</span><span class="pln"> </span><span class="pun">*</span><span class="pln"> </span><span class="lit">2.0</span><span class="pln"> </span><span class="pun">-</span><span class="pln"> </span><span class="lit">1.0</span><span class="pln"> </span></code></li><li class="L5"><code class="language-python"></code></li><li class="L6"><code class="language-python"></code></li><li class="L7"><code class="language-python"><span class="pln"> </span><span class="com"># 打开标签文件</span></code></li><li class="L8"><code class="language-python"><span class="pln"> </span><span class="kwd">with</span><span class="pln"> open</span><span class="pun">(</span><span class="pln"> filename </span><span class="pun">+</span><span class="pln"> </span><span class="str">"-labels-idx1-ubyte"</span><span class="pun">,</span><span class="pln"> </span><span class="str">"rb"</span><span class="pun">)</span><span class="pln"> </span><span class="kwd">as</span><span class="pln"> l</span><span class="pun">:</span><span class="pln"> </span></code></li><li class="L9"><code class="language-python"><span class="pln"> </span><span class="com"># 读取开头的两个参数</span></code></li><li class="L0"><code class="language-python"><span class="pln"> magic</span><span class="pun">,</span><span class="pln"> n </span><span class="pun">=</span><span class="pln"> struct</span><span class="pun">.</span><span class="pln">upack</span><span class="pun">(</span><span class="str">"&gt;II"</span><span class="pun">,</span><span class="pln"> l</span><span class="pun">.</span><span class="pln">read</span><span class="pun">(</span><span class="lit">8</span><span class="pun">))</span><span class="pln"> </span></code></li><li class="L1"><code class="language-python"><span class="pln"> </span><span class="com"># 以无符号字节为单位一个一个的读取数据</span></code></li><li class="L2"><code class="language-python"><span class="pln"> labels </span><span class="pun">=</span><span class="pln"> np</span><span class="pun">.</span><span class="pln">fromfile</span><span class="pun">(</span><span class="pln">l</span><span class="pun">,</span><span class="pln"> </span><span class="str">'ubyte'</span><span class="pun">,</span><span class="pln"> count</span><span class="pun">=</span><span class="pln">n</span><span class="pun">).</span><span class="pln">astype</span><span class="pun">(</span><span class="str">"int"</span><span class="pun">)</span><span class="pln"> </span></code></li><li class="L3"><code class="language-python"></code></li><li class="L4"><code class="language-python"><span class="pln"> </span><span class="kwd">for</span><span class="pln"> i </span><span class="kwd">in</span><span class="pln"> xrange</span><span class="pun">(</span><span class="pln">n</span><span class="pun">):</span></code></li><li class="L5"><code class="language-python"><span class="pln"> </span><span class="kwd">yield</span><span class="pln"> </span><span class="pun">{</span><span class="str">"pixel"</span><span class="pun">:</span><span class="pln"> images</span><span class="pun">[</span><span class="pln">i</span><span class="pun">,</span><span class="pln"> </span><span class="pun">:],</span><span class="pln"> </span><span class="str">'label'</span><span class="pun">:</span><span class="pln"> labels</span><span class="pun">[</span><span class="pln">i</span><span class="pun">]}</span></code></li></ol></pre><div class="md-section-divider"></div><div class="md-section-divider"></div><div class="md-section-divider"></div><pre class="prettyprint linenums prettyprinted" data-anchor-id="q1st"><ol class="linenums"><li class="L0"><code class="language-python"><span class="pln"> </span><span class="kwd">if</span><span class="pln"> </span><span class="kwd">not</span><span class="pln"> is_predict</span><span class="pun">:</span></code></li><li class="L1"><code class="language-python"><span class="pln"> data_dir </span><span class="pun">=</span><span class="pln"> </span><span class="str">'./data/'</span></code></li><li class="L2"><code class="language-python"><span class="pln"> define_py_data_sources2</span><span class="pun">(</span></code></li><li class="L3"><code class="language-python"><span class="pln"> train_list</span><span class="pun">=</span><span class="pln">data_dir </span><span class="pun">+</span><span class="pln"> </span><span class="str">'train.list'</span><span class="pun">,</span></code></li><li class="L4"><code class="language-python"><span class="pln"> test_list</span><span class="pun">=</span><span class="pln">data_dir </span><span class="pun">+</span><span class="pln"> </span><span class="str">'test.list'</span><span class="pun">,</span></code></li><li class="L5"><code class="language-python"><span class="pln"> module</span><span class="pun">=</span><span class="str">'mnist_provider'</span><span class="pun">,</span></code></li><li class="L6"><code class="language-python"><span class="pln"> obj</span><span class="pun">=</span><span class="str">'process'</span><span class="pun">)</span></code></li></ol></pre><div class="md-section-divider"></div><div class="md-section-divider"></div><pre class="prettyprint linenums prettyprinted" data-anchor-id="k2j8"><ol class="linenums"><li class="L0"><code class="language-python"><span class="pln">settings</span><span class="pun">(</span></code></li><li class="L1"><code class="language-python"><span class="pln"> batch_size</span><span class="pun">=</span><span class="lit">128</span><span class="pun">,</span></code></li><li class="L2"><code class="language-python"><span class="pln"> learning_rate</span><span class="pun">=</span><span class="lit">0.1</span><span class="pln"> </span><span class="pun">/</span><span class="pln"> </span><span class="lit">128.0</span><span class="pun">,</span></code></li><li class="L3"><code class="language-python"><span class="pln"> learning_method</span><span class="pun">=</span><span class="typ">MomentumOptimizer</span><span class="pun">(</span><span class="lit">0.9</span><span class="pun">),</span></code></li><li class="L4"><code class="language-python"><span class="pln"> regularization</span><span class="pun">=</span><span class="pln">L2Regularization</span><span class="pun">(</span><span class="lit">0.0005</span><span class="pln"> </span><span class="pun">*</span><span class="pln"> </span><span class="lit">128</span><span class="pun">))</span></code></li></ol></pre><div class="md-section-divider"></div><div class="md-section-divider"></div><div class="md-section-divider"></div><pre class="prettyprint linenums prettyprinted" data-anchor-id="shvv"><ol class="linenums"><li class="L0"><code class="language-python"><span class="pln">data_size </span><span class="pun">=</span><span class="pln"> </span><span class="lit">1</span><span class="pln"> </span><span class="pun">*</span><span class="pln"> </span><span class="lit">28</span><span class="pln"> </span><span class="pun">*</span><span class="pln"> </span><span class="lit">28</span></code></li><li class="L1"><code class="language-python"><span class="pln">label_size </span><span class="pun">=</span><span class="pln"> </span><span class="lit">10</span></code></li><li class="L2"><code class="language-python"><span class="pln">img </span><span class="pun">=</span><span class="pln"> data_layer</span><span class="pun">(</span><span class="pln">name</span><span class="pun">=</span><span class="str">'pixel'</span><span class="pun">,</span><span class="pln"> size</span><span class="pun">=</span><span class="pln">data_size</span><span class="pun">)</span></code></li><li class="L3"><code class="language-python"></code></li><li class="L4"><code class="language-python"><span class="pln">predict </span><span class="pun">=</span><span class="pln"> softmax_regression</span><span class="pun">(</span><span class="pln">img</span><span class="pun">)</span><span class="pln"> </span><span class="com"># Softmax回归</span></code></li><li class="L5"><code class="language-python"><span class="com">#predict = multilayer_perceptron(img) #多层感知器</span></code></li><li class="L6"><code class="language-python"><span class="com">#predict = convolutional_neural_network(img) #LeNet5卷积神经网络</span></code></li><li class="L7"><code class="language-python"></code></li><li class="L8"><code class="language-python"><span class="kwd">if</span><span class="pln"> </span><span class="kwd">not</span><span class="pln"> is_predict</span><span class="pun">:</span></code></li><li class="L9"><code class="language-python"><span class="pln"> lbl </span><span class="pun">=</span><span class="pln"> data_layer</span><span class="pun">(</span><span class="pln">name</span><span class="pun">=</span><span class="str">"label"</span><span class="pun">,</span><span class="pln"> size</span><span class="pun">=</span><span class="pln">label_size</span><span class="pun">)</span></code></li><li class="L0"><code class="language-python"><span class="pln"> inputs</span><span class="pun">(</span><span class="pln">img</span><span class="pun">,</span><span class="pln"> lbl</span><span class="pun">)</span></code></li><li class="L1"><code class="language-python"><span class="pln"> outputs</span><span class="pun">(</span><span class="pln">classification_cost</span><span class="pun">(</span><span class="pln">input</span><span class="pun">=</span><span class="pln">predict</span><span class="pun">,</span><span class="pln"> label</span><span class="pun">=</span><span class="pln">lbl</span><span class="pun">))</span></code></li><li class="L2"><code class="language-python"><span class="kwd">else</span><span class="pun">:</span></code></li><li class="L3"><code class="language-python"><span class="pln"> outputs</span><span class="pun">(</span><span class="pln">predict</span><span class="pun">)</span></code></li></ol></pre><div class="md-section-divider"></div><div class="md-section-divider"></div><pre class="prettyprint linenums prettyprinted" data-anchor-id="u3kh"><ol class="linenums"><li class="L0"><code class="language-python"><span class="kwd">def</span><span class="pln"> softmax_regression</span><span class="pun">(</span><span class="pln">img</span><span class="pun">):</span></code></li><li class="L1"><code class="language-python"><span class="pln"> predict </span><span class="pun">=</span><span class="pln"> fc_layer</span><span class="pun">(</span><span class="pln">input</span><span class="pun">=</span><span class="pln">img</span><span class="pun">,</span><span class="pln"> size</span><span class="pun">=</span><span class="lit">10</span><span class="pun">,</span><span class="pln"> act</span><span class="pun">=</span><span class="typ">SoftmaxActivation</span><span class="pun">())</span></code></li><li class="L2"><code class="language-python"><span class="pln"> </span><span class="kwd">return</span><span class="pln"> predict</span></code></li></ol></pre><div class="md-section-divider"></div><div class="md-section-divider"></div><pre class="prettyprint linenums prettyprinted" data-anchor-id="1198"><ol class="linenums"><li class="L0"><code class="language-python"><span class="kwd">def</span><span class="pln"> multilayer_perceptron</span><span class="pun">(</span><span class="pln">img</span><span class="pun">):</span></code></li><li class="L1"><code class="language-python"><span class="pln"> </span><span class="com"># 第一个全连接层,激活函数为ReLU</span></code></li><li class="L2"><code class="language-python"><span class="pln"> hidden1 </span><span class="pun">=</span><span class="pln"> fc_layer</span><span class="pun">(</span><span class="pln">input</span><span class="pun">=</span><span class="pln">img</span><span class="pun">,</span><span class="pln"> size</span><span class="pun">=</span><span class="lit">128</span><span class="pun">,</span><span class="pln"> act</span><span class="pun">=</span><span class="typ">ReluActivation</span><span class="pun">())</span></code></li><li class="L3"><code class="language-python"><span class="pln"> </span><span class="com"># 第二个全连接层,激活函数为ReLU</span></code></li><li class="L4"><code class="language-python"><span class="pln"> hidden2 </span><span class="pun">=</span><span class="pln"> fc_layer</span><span class="pun">(</span><span class="pln">input</span><span class="pun">=</span><span class="pln">hidden1</span><span class="pun">,</span><span class="pln"> size</span><span class="pun">=</span><span class="lit">64</span><span class="pun">,</span><span class="pln"> act</span><span class="pun">=</span><span class="typ">ReluActivation</span><span class="pun">())</span></code></li><li class="L5"><code class="language-python"><span class="pln"> </span><span class="com"># 以softmax为激活函数的全连接输出层,输出层的大小必须为数字的个数10</span></code></li><li class="L6"><code class="language-python"><span class="pln"> predict </span><span class="pun">=</span><span class="pln"> fc_layer</span><span class="pun">(</span><span class="pln">input</span><span class="pun">=</span><span class="pln">hidden2</span><span class="pun">,</span><span class="pln"> size</span><span class="pun">=</span><span class="lit">10</span><span class="pun">,</span><span class="pln"> act</span><span class="pun">=</span><span class="typ">SoftmaxActivation</span><span class="pun">())</span></code></li><li class="L7"><code class="language-python"><span class="pln"> </span><span class="kwd">return</span><span class="pln"> predict</span></code></li></ol></pre><div class="md-section-divider"></div><div class="md-section-divider"></div><pre class="prettyprint linenums prettyprinted" data-anchor-id="acw1"><ol class="linenums"><li class="L0"><code class="language-python"><span class="kwd">def</span><span class="pln"> convolutional_neural_network</span><span class="pun">(</span><span class="pln">img</span><span class="pun">):</span></code></li><li class="L1"><code class="language-python"><span class="pln"> </span><span class="com"># 第一个卷积-池化层</span></code></li><li class="L2"><code class="language-python"><span class="pln"> conv_pool_1 </span><span class="pun">=</span><span class="pln"> simple_img_conv_pool</span><span class="pun">(</span></code></li><li class="L3"><code class="language-python"><span class="pln"> input</span><span class="pun">=</span><span class="pln">img</span><span class="pun">,</span></code></li><li class="L4"><code class="language-python"><span class="pln"> filter_size</span><span class="pun">=</span><span class="lit">5</span><span class="pun">,</span></code></li><li class="L5"><code class="language-python"><span class="pln"> num_filters</span><span class="pun">=</span><span class="lit">20</span><span class="pun">,</span></code></li><li class="L6"><code class="language-python"><span class="pln"> num_channel</span><span class="pun">=</span><span class="lit">1</span><span class="pun">,</span></code></li><li class="L7"><code class="language-python"><span class="pln"> pool_size</span><span class="pun">=</span><span class="lit">2</span><span class="pun">,</span></code></li><li class="L8"><code class="language-python"><span class="pln"> pool_stride</span><span class="pun">=</span><span class="lit">2</span><span class="pun">,</span></code></li><li class="L9"><code class="language-python"><span class="pln"> act</span><span class="pun">=</span><span class="typ">TanhActivation</span><span class="pun">())</span></code></li><li class="L0"><code class="language-python"><span class="pln"> </span><span class="com"># 第二个卷积-池化层</span></code></li><li class="L1"><code class="language-python"><span class="pln"> conv_pool_2 </span><span class="pun">=</span><span class="pln"> simple_img_conv_pool</span><span class="pun">(</span></code></li><li class="L2"><code class="language-python"><span class="pln"> input</span><span class="pun">=</span><span class="pln">conv_pool_1</span><span class="pun">,</span></code></li><li class="L3"><code class="language-python"><span class="pln"> filter_size</span><span class="pun">=</span><span class="lit">5</span><span class="pun">,</span></code></li><li class="L4"><code class="language-python"><span class="pln"> num_filters</span><span class="pun">=</span><span class="lit">50</span><span class="pun">,</span></code></li><li class="L5"><code class="language-python"><span class="pln"> num_channel</span><span class="pun">=</span><span class="lit">20</span><span class="pun">,</span></code></li><li class="L6"><code class="language-python"><span class="pln"> pool_size</span><span class="pun">=</span><span class="lit">2</span><span class="pun">,</span></code></li><li class="L7"><code class="language-python"><span class="pln"> pool_stride</span><span class="pun">=</span><span class="lit">2</span><span class="pun">,</span></code></li><li class="L8"><code class="language-python"><span class="pln"> act</span><span class="pun">=</span><span class="typ">TanhActivation</span><span class="pun">())</span></code></li><li class="L9"><code class="language-python"><span class="pln"> </span><span class="com"># 全连接层</span></code></li><li class="L0"><code class="language-python"><span class="pln"> fc1 </span><span class="pun">=</span><span class="pln"> fc_layer</span><span class="pun">(</span><span class="pln">input</span><span class="pun">=</span><span class="pln">conv_pool_2</span><span class="pun">,</span><span class="pln"> size</span><span class="pun">=</span><span class="lit">128</span><span class="pun">,</span><span class="pln"> act</span><span class="pun">=</span><span class="typ">TanhActivation</span><span class="pun">())</span></code></li><li class="L1"><code class="language-python"><span class="pln"> </span><span class="com"># 以softmax为激活函数的全连接输出层,输出层的大小必须为数字的个数10</span></code></li><li class="L2"><code class="language-python"><span class="pln"> predict </span><span class="pun">=</span><span class="pln"> fc_layer</span><span class="pun">(</span><span class="pln">input</span><span class="pun">=</span><span class="pln">fc1</span><span class="pun">,</span><span class="pln"> size</span><span class="pun">=</span><span class="lit">10</span><span class="pun">,</span><span class="pln"> act</span><span class="pun">=</span><span class="typ">SoftmaxActivation</span><span class="pun">())</span></code></li><li class="L3"><code class="language-python"><span class="pln"> </span><span class="kwd">return</span><span class="pln"> predict</span></code></li></ol></pre><div class="md-section-divider"></div><div class="md-section-divider"></div><div class="md-section-divider"></div><pre class="prettyprint linenums prettyprinted" data-anchor-id="ad6q"><ol class="linenums"><li class="L0"><code class="language-bash"><span class="pln">config</span><span class="pun">=</span><span class="pln">mnist_model</span><span class="pun">.</span><span class="pln">py </span><span class="com"># 在mnist_model.py中可以选择网络</span></code></li><li class="L1"><code class="language-bash"><span class="pln">output</span><span class="pun">=./</span><span class="pln">softmax_mnist_model </span></code></li><li class="L2"><code class="language-bash"><span class="pln">log</span><span class="pun">=</span><span class="pln">softmax_train</span><span class="pun">.</span><span class="pln">log </span></code></li><li class="L3"><code class="language-bash"></code></li><li class="L4"><code class="language-bash"><span class="pln">paddle train \</span></code></li><li class="L5"><code class="language-bash"><span class="pun">--</span><span class="pln">config</span><span class="pun">=</span><span class="pln">$config \ </span><span class="com"># 网络配置的脚本</span></code></li><li class="L6"><code class="language-bash"><span class="pun">--</span><span class="pln">dot_period</span><span class="pun">=</span><span class="lit">10</span><span class="pln"> \ </span><span class="com"># 每训练 `dot_period` 个批次后打印一个 `.`</span></code></li><li class="L7"><code class="language-bash"><span class="pun">--</span><span class="pln">log_period</span><span class="pun">=</span><span class="lit">100</span><span class="pln"> \ </span><span class="com"># 每隔多少batch打印一次日志</span></code></li><li class="L8"><code class="language-bash"><span class="pun">--</span><span class="pln">test_all_data_in_one_period</span><span class="pun">=</span><span class="lit">1</span><span class="pln"> \ </span><span class="com"># 每次测试是否用所有的数据</span></code></li><li class="L9"><code class="language-bash"><span class="pun">--</span><span class="pln">use_gpu</span><span class="pun">=</span><span class="lit">0</span><span class="pln"> \ </span><span class="com"># 是否使用GPU</span></code></li><li class="L0"><code class="language-bash"><span class="pun">--</span><span class="pln">trainer_count</span><span class="pun">=</span><span class="lit">1</span><span class="pln"> \ </span><span class="com"># 使用CPU或GPU的个数</span></code></li><li class="L1"><code class="language-bash"><span class="pun">--</span><span class="pln">num_passes</span><span class="pun">=</span><span class="lit">100</span><span class="pln"> \ </span><span class="com"># 训练进行的轮数(每次训练使用完所有数据为1轮)</span></code></li><li class="L2"><code class="language-bash"><span class="pun">--</span><span class="pln">save_dir</span><span class="pun">=</span><span class="pln">$output \ </span><span class="com"># 模型存储的位置</span></code></li><li class="L3"><code class="language-bash"><span class="lit">2</span><span class="pun">&gt;&amp;</span><span class="lit">1</span><span class="pln"> </span><span class="pun">|</span><span class="pln"> tee $log</span></code></li><li class="L4"><code class="language-bash"></code></li><li class="L5"><code class="language-bash"><span class="pln">python </span><span class="pun">-</span><span class="pln">m paddle</span><span class="pun">.</span><span class="pln">utils</span><span class="pun">.</span><span class="pln">plotcurve </span><span class="pun">-</span><span class="pln">i $log </span><span class="pun">&gt;</span><span class="pln"> plot</span><span class="pun">.</span><span class="pln">png</span></code></li></ol></pre><div class="md-section-divider"></div><pre class="prettyprint linenums prettyprinted" data-anchor-id="76fl"><ol class="linenums"><li class="L0"><code><span class="pln">I0117 </span><span class="lit">12</span><span class="pun">:</span><span class="lit">52</span><span class="pun">:</span><span class="lit">29.628617</span><span class="pln"> </span><span class="lit">4538</span><span class="pln"> </span><span class="typ">TrainerInternal</span><span class="pun">.</span><span class="pln">cpp</span><span class="pun">:</span><span class="lit">165</span><span class="pun">]</span><span class="pln"> </span><span class="typ">Batch</span><span class="pun">=</span><span class="lit">100</span><span class="pln"> samples</span><span class="pun">=</span><span class="lit">12800</span><span class="pln"> </span><span class="typ">AvgCost</span><span class="pun">=</span><span class="lit">2.63996</span><span class="pln"> </span><span class="typ">CurrentCost</span><span class="pun">=</span><span class="lit">2.63996</span><span class="pln"> </span><span class="typ">Eval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.241172</span><span class="pln"> </span><span class="typ">CurrentEval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.241172</span><span class="pln"> </span></code></li><li class="L1"><code><span class="pun">.........</span></code></li><li class="L2"><code><span class="pln">I0117 </span><span class="lit">12</span><span class="pun">:</span><span class="lit">52</span><span class="pun">:</span><span class="lit">29.768741</span><span class="pln"> </span><span class="lit">4538</span><span class="pln"> </span><span class="typ">TrainerInternal</span><span class="pun">.</span><span class="pln">cpp</span><span class="pun">:</span><span class="lit">165</span><span class="pun">]</span><span class="pln"> </span><span class="typ">Batch</span><span class="pun">=</span><span class="lit">200</span><span class="pln"> samples</span><span class="pun">=</span><span class="lit">25600</span><span class="pln"> </span><span class="typ">AvgCost</span><span class="pun">=</span><span class="lit">1.74027</span><span class="pln"> </span><span class="typ">CurrentCost</span><span class="pun">=</span><span class="lit">0.840582</span><span class="pln"> </span><span class="typ">Eval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.185234</span><span class="pln"> </span><span class="typ">CurrentEval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.129297</span><span class="pln"> </span></code></li><li class="L3"><code><span class="pun">.........</span></code></li><li class="L4"><code><span class="pln">I0117 </span><span class="lit">12</span><span class="pun">:</span><span class="lit">52</span><span class="pun">:</span><span class="lit">29.916970</span><span class="pln"> </span><span class="lit">4538</span><span class="pln"> </span><span class="typ">TrainerInternal</span><span class="pun">.</span><span class="pln">cpp</span><span class="pun">:</span><span class="lit">165</span><span class="pun">]</span><span class="pln"> </span><span class="typ">Batch</span><span class="pun">=</span><span class="lit">300</span><span class="pln"> samples</span><span class="pun">=</span><span class="lit">38400</span><span class="pln"> </span><span class="typ">AvgCost</span><span class="pun">=</span><span class="lit">1.42119</span><span class="pln"> </span><span class="typ">CurrentCost</span><span class="pun">=</span><span class="lit">0.783026</span><span class="pln"> </span><span class="typ">Eval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.167786</span><span class="pln"> </span><span class="typ">CurrentEval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.132891</span><span class="pln"> </span></code></li><li class="L5"><code><span class="pun">.........</span></code></li><li class="L6"><code><span class="pln">I0117 </span><span class="lit">12</span><span class="pun">:</span><span class="lit">52</span><span class="pun">:</span><span class="lit">30.061213</span><span class="pln"> </span><span class="lit">4538</span><span class="pln"> </span><span class="typ">TrainerInternal</span><span class="pun">.</span><span class="pln">cpp</span><span class="pun">:</span><span class="lit">165</span><span class="pun">]</span><span class="pln"> </span><span class="typ">Batch</span><span class="pun">=</span><span class="lit">400</span><span class="pln"> samples</span><span class="pun">=</span><span class="lit">51200</span><span class="pln"> </span><span class="typ">AvgCost</span><span class="pun">=</span><span class="lit">1.23965</span><span class="pln"> </span><span class="typ">CurrentCost</span><span class="pun">=</span><span class="lit">0.695054</span><span class="pln"> </span><span class="typ">Eval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.160039</span><span class="pln"> </span><span class="typ">CurrentEval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.136797</span><span class="pln"> </span></code></li><li class="L7"><code><span class="pun">......</span><span class="pln">I0117 </span><span class="lit">12</span><span class="pun">:</span><span class="lit">52</span><span class="pun">:</span><span class="lit">30.223270</span><span class="pln"> </span><span class="lit">4538</span><span class="pln"> </span><span class="typ">TrainerInternal</span><span class="pun">.</span><span class="pln">cpp</span><span class="pun">:</span><span class="lit">181</span><span class="pun">]</span><span class="pln"> </span><span class="typ">Pass</span><span class="pun">=</span><span class="lit">0</span><span class="pln"> </span><span class="typ">Batch</span><span class="pun">=</span><span class="lit">469</span><span class="pln"> samples</span><span class="pun">=</span><span class="lit">60000</span><span class="pln"> </span><span class="typ">AvgCost</span><span class="pun">=</span><span class="lit">1.1628</span><span class="pln"> </span><span class="typ">Eval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.156233</span><span class="pln"> </span></code></li><li class="L8"><code><span class="pln">I0117 </span><span class="lit">12</span><span class="pun">:</span><span class="lit">52</span><span class="pun">:</span><span class="lit">30.366894</span><span class="pln"> </span><span class="lit">4538</span><span class="pln"> </span><span class="typ">Tester</span><span class="pun">.</span><span class="pln">cpp</span><span class="pun">:</span><span class="lit">109</span><span class="pun">]</span><span class="pln"> </span><span class="typ">Test</span><span class="pln"> samples</span><span class="pun">=</span><span class="lit">10000</span><span class="pln"> cost</span><span class="pun">=</span><span class="lit">0.50777</span><span class="pln"> </span><span class="typ">Eval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.0978</span><span class="pln"> </span></code></li></ol></pre><div class="md-section-divider"></div><pre class="prettyprint linenums prettyprinted" data-anchor-id="g57s"><ol class="linenums"><li class="L0"><code class="language-bash"><span class="pln">python plot_cost</span><span class="pun">.</span><span class="pln">py softmax_train</span><span class="pun">.</span><span class="pln">log </span></code></li></ol></pre><div class="md-section-divider"></div><pre class="prettyprint linenums prettyprinted" data-anchor-id="bjm8"><ol class="linenums"><li class="L0"><code class="language-bash"><span class="pln">python evaluate</span><span class="pun">.</span><span class="pln">py softmax_train</span><span class="pun">.</span><span class="pln">log</span></code></li></ol></pre><div class="md-section-divider"></div><p align="center" data-anchor-id="j38j">
</ul><p data-anchor-id="tss0">更详细的介绍请参考<a href="https://en.wikipedia.org/wiki/Activation_function" target="_blank">维基百科激活函数</a></p><div class="md-section-divider"></div><h2 data-anchor-id="xbcu" id="数据准备">数据准备</h2><div class="md-section-divider"></div><h3 data-anchor-id="c26s" id="数据介绍与下载">数据介绍与下载</h3><p data-anchor-id="pim6">执行以下命令,下载<a href="http://yann.lecun.com/exdb/mnist/" target="_blank">MNIST</a>数据库并解压缩,然后将训练集和测试集的地址分别写入train.list和test.list两个文件,供PaddlePaddle读取。</p><pre data-anchor-id="hdhx"><code>./data/get_mnist_data.sh
</code></pre><p data-anchor-id="fbdq">将下载下来的数据进行 <code>gzip</code> 解压,可以在文件夹 <code>data/raw_data</code> 中找到以下文件:</p><table data-anchor-id="tpwt" class="table table-striped-white table-bordered">
<thead>
<tr>
<th>文件名称</th>
<th>说明</th>
</tr>
</thead>
<tbody><tr>
<td>train-images-idx3-ubyte</td>
<td>训练数据图片,60,000条数据</td>
</tr>
<tr>
<td>train-labels-idx1-ubyte</td>
<td>训练数据标签,60,000条数据</td>
</tr>
<tr>
<td>t10k-images-idx3-ubyte</td>
<td>测试数据图片,10,000条数据</td>
</tr>
<tr>
<td>t10k-labels-idx1-ubyte</td>
<td>测试数据标签,10,000条数据</td>
</tr>
</tbody></table><p data-anchor-id="mste">用户可以通过以下脚本随机绘制10张图片(可参考图1):</p><pre data-anchor-id="27s8"><code>./load_data.py
</code></pre><div class="md-section-divider"></div><h3 data-anchor-id="ukke" id="提供数据给paddlepaddle">提供数据给PaddlePaddle</h3><p data-anchor-id="sdul">我们使用python接口传递数据给系统,下面 <code>mnist_provider.py</code>针对MNIST数据给出了完整示例。</p><pre data-anchor-id="x3bc"><code># Define a py data provider
@provider(
input_types={'pixel': dense_vector(28 * 28),
'label': integer_value(10)})
def process(settings, filename): # settings is not used currently.
# 打开图片文件
with open( filename + "-images-idx3-ubyte", "rb") as f:
# 读取开头的四个参数,magic代表数据的格式,n代表数据的总量,rows和cols分别代表行数和列数
magic, n, rows, cols = struct.upack("&gt;IIII", f.read(16))
# 以无符号字节为单位一个一个的读取数据
images = np.fromfile(
f, 'ubyte',
count=n * rows * cols).reshape(n, rows, cols).astype('float32')
# 将0~255的数据归一化到[-1,1]的区间
images = images / 255.0 * 2.0 - 1.0
# 打开标签文件
with open( filename + "-labels-idx1-ubyte", "rb") as l:
# 读取开头的两个参数
magic, n = struct.upack("&gt;II", l.read(8))
# 以无符号字节为单位一个一个的读取数据
labels = np.fromfile(l, 'ubyte', count=n).astype("int")
for i in xrange(n):
yield {"pixel": images[i, :], 'label': labels[i]}
</code></pre><div class="md-section-divider"></div><h2 data-anchor-id="oilx" id="模型配置说明">模型配置说明</h2><div class="md-section-divider"></div><h3 data-anchor-id="2lby" id="数据定义">数据定义</h3><p data-anchor-id="o0hn">在模型配置中,定义通过 <code>define_py_data_sources2</code> 函数从 <code>dataprovider</code> 中读入数据。如果该配置用于预测,则不需要数据定义部分。</p><pre data-anchor-id="tzmf"><code>if not is_predict:
data_dir = './data/'
define_py_data_sources2(
train_list=data_dir + 'train.list',
test_list=data_dir + 'test.list',
module='mnist_provider',
obj='process')
</code></pre><div class="md-section-divider"></div><h3 data-anchor-id="hhqm" id="算法配置">算法配置</h3><p data-anchor-id="y5hw">指定训练相关的参数。</p><ul data-anchor-id="ytth">
<li>batch_size: 表示神经网络每次训练使用的数据为128条。</li>
<li>训练速度(learning_rate): 迭代的速度,与网络的训练收敛速度有关系。</li>
<li>训练方法(learning_method): 代表训练过程在更新权重时采用动量优化器 <code>MomentumOptimizer</code> ,其中参数0.9代表动量优化每次保持前一次速度的0.9倍。</li>
<li><p>正则化(regularization): 是防止网络过拟合的一种手段,此处采用L2正则化。</p>
<pre class="prettyprint linenums prettyprinted"><ol class="linenums"><li class="L0"><code class="language-python"><span class="pln">settings</span><span class="pun">(</span></code></li><li class="L1"><code class="language-python"><span class="pln"> batch_size</span><span class="pun">=</span><span class="lit">128</span><span class="pun">,</span></code></li><li class="L2"><code class="language-python"><span class="pln"> learning_rate</span><span class="pun">=</span><span class="lit">0.1</span><span class="pln"> </span><span class="pun">/</span><span class="pln"> </span><span class="lit">128.0</span><span class="pun">,</span></code></li><li class="L3"><code class="language-python"><span class="pln"> learning_method</span><span class="pun">=</span><span class="typ">MomentumOptimizer</span><span class="pun">(</span><span class="lit">0.9</span><span class="pun">),</span></code></li><li class="L4"><code class="language-python"><span class="pln"> regularization</span><span class="pun">=</span><span class="pln">L2Regularization</span><span class="pun">(</span><span class="lit">0.0005</span><span class="pln"> </span><span class="pun">*</span><span class="pln"> </span><span class="lit">128</span><span class="pun">))</span></code></li></ol></pre></li>
</ul><div class="md-section-divider"></div><h3 data-anchor-id="23bp" id="模型结构">模型结构</h3><div class="md-section-divider"></div><h4 data-anchor-id="zdyh" id="整体结构">整体结构</h4><p data-anchor-id="6p9n">首先通过<code>data_layer</code>调用来获取数据,然后调用分类器(这里我们提供了三个不同的分类器)得到分类结果。训练时,对该结果计算其损失函数,分类问题常常选择交叉熵损失函数;而预测时直接输出该结果即可。</p><pre data-anchor-id="wjxj"><code>data_size = 1 * 28 * 28
label_size = 10
img = data_layer(name='pixel', size=data_size)
predict = softmax_regression(img) # Softmax回归
#predict = multilayer_perceptron(img) #多层感知器
#predict = convolutional_neural_network(img) #LeNet5卷积神经网络
if not is_predict:
lbl = data_layer(name="label", size=label_size)
inputs(img, lbl)
outputs(classification_cost(input=predict, label=lbl))
else:
outputs(predict)
</code></pre><div class="md-section-divider"></div><h4 data-anchor-id="iue9" id="softmax回归">Softmax回归</h4><p data-anchor-id="p6mv">只通过一层简单的以softmax为激活函数的全连接层,就可以得到分类的结果。</p><pre data-anchor-id="nop8"><code>def softmax_regression(img):
predict = fc_layer(input=img, size=10, act=SoftmaxActivation())
return predict
</code></pre><div class="md-section-divider"></div><h4 data-anchor-id="obqh" id="多层感知器">多层感知器</h4><p data-anchor-id="z8sk">下面代码实现了一个含有两个隐藏层(即全连接层)的多层感知器。其中两个隐藏层的激活函数均采用ReLU,输出层的激活函数用Softmax。</p><pre data-anchor-id="wyha"><code>def multilayer_perceptron(img):
# 第一个全连接层,激活函数为ReLU
hidden1 = fc_layer(input=img, size=128, act=ReluActivation())
# 第二个全连接层,激活函数为ReLU
hidden2 = fc_layer(input=hidden1, size=64, act=ReluActivation())
# 以softmax为激活函数的全连接输出层,输出层的大小必须为数字的个数10
predict = fc_layer(input=hidden2, size=10, act=SoftmaxActivation())
return predict
</code></pre><div class="md-section-divider"></div><h4 data-anchor-id="eiz3" id="卷积神经网络lenet-5">卷积神经网络LeNet-5</h4><p data-anchor-id="w1da">以下为LeNet-5的网络结构:输入的二维图像,首先经过两次卷积层到池化层,再经过全连接层,最后使用以softmax为激活函数的全连接层作为输出层。</p><pre data-anchor-id="8abp"><code>def convolutional_neural_network(img):
# 第一个卷积-池化层
conv_pool_1 = simple_img_conv_pool(
input=img,
filter_size=5,
num_filters=20,
num_channel=1,
pool_size=2,
pool_stride=2,
act=TanhActivation())
# 第二个卷积-池化层
conv_pool_2 = simple_img_conv_pool(
input=conv_pool_1,
filter_size=5,
num_filters=50,
num_channel=20,
pool_size=2,
pool_stride=2,
act=TanhActivation())
# 全连接层
fc1 = fc_layer(input=conv_pool_2, size=128, act=TanhActivation())
# 以softmax为激活函数的全连接输出层,输出层的大小必须为数字的个数10
predict = fc_layer(input=fc1, size=10, act=SoftmaxActivation())
return predict
</code></pre><div class="md-section-divider"></div><h2 data-anchor-id="lotz" id="训练模型">训练模型</h2><div class="md-section-divider"></div><h3 data-anchor-id="125p" id="训练命令及日志">训练命令及日志</h3><ol data-anchor-id="gks8">
<li><p>通过配置训练脚本 <code>train.sh</code> 来执行训练过程:</p>
<pre class="prettyprint linenums prettyprinted"><ol class="linenums"><li class="L0"><code class="language-bash"><span class="pln">config</span><span class="pun">=</span><span class="pln">mnist_model</span><span class="pun">.</span><span class="pln">py </span><span class="com"># 在mnist_model.py中可以选择网络</span></code></li><li class="L1"><code class="language-bash"><span class="pln">output</span><span class="pun">=./</span><span class="pln">softmax_mnist_model </span></code></li><li class="L2"><code class="language-bash"><span class="pln">log</span><span class="pun">=</span><span class="pln">softmax_train</span><span class="pun">.</span><span class="pln">log </span></code></li><li class="L3"><code class="language-bash"></code></li><li class="L4"><code class="language-bash"><span class="pln">paddle train \</span></code></li><li class="L5"><code class="language-bash"><span class="pun">--</span><span class="pln">config</span><span class="pun">=</span><span class="pln">$config \ </span><span class="com"># 网络配置的脚本</span></code></li><li class="L6"><code class="language-bash"><span class="pln"> </span><span class="pun">--</span><span class="pln">dot_period</span><span class="pun">=</span><span class="lit">10</span><span class="pln"> \ </span><span class="com"># 每训练 `dot_period` 个批次后打印一个 `.`</span></code></li><li class="L7"><code class="language-bash"><span class="pln"> </span><span class="pun">--</span><span class="pln">log_period</span><span class="pun">=</span><span class="lit">100</span><span class="pln"> \ </span><span class="com"># 每隔多少batch打印一次日志</span></code></li><li class="L8"><code class="language-bash"><span class="pln"> </span><span class="pun">--</span><span class="pln">test_all_data_in_one_period</span><span class="pun">=</span><span class="lit">1</span><span class="pln"> \ </span><span class="com"># 每次测试是否用所有的数据</span></code></li><li class="L9"><code class="language-bash"><span class="pln"> </span><span class="pun">--</span><span class="pln">use_gpu</span><span class="pun">=</span><span class="lit">0</span><span class="pln"> \ </span><span class="com"># 是否使用GPU</span></code></li><li class="L0"><code class="language-bash"><span class="pln"> </span><span class="pun">--</span><span class="pln">trainer_count</span><span class="pun">=</span><span class="lit">1</span><span class="pln"> \ </span><span class="com"># 使用CPU或GPU的个数</span></code></li><li class="L1"><code class="language-bash"><span class="pln"> </span><span class="pun">--</span><span class="pln">num_passes</span><span class="pun">=</span><span class="lit">100</span><span class="pln"> \ </span><span class="com"># 训练进行的轮数(每次训练使用完所有数据为1轮)</span></code></li><li class="L2"><code class="language-bash"><span class="pln"> </span><span class="pun">--</span><span class="pln">save_dir</span><span class="pun">=</span><span class="pln">$output \ </span><span class="com"># 模型存储的位置</span></code></li><li class="L3"><code class="language-bash"><span class="lit">2</span><span class="pun">&gt;&amp;</span><span class="lit">1</span><span class="pln"> </span><span class="pun">|</span><span class="pln"> tee $log</span></code></li><li class="L4"><code class="language-bash"></code></li><li class="L5"><code class="language-bash"><span class="pln">python </span><span class="pun">-</span><span class="pln">m paddle</span><span class="pun">.</span><span class="pln">utils</span><span class="pun">.</span><span class="pln">plotcurve </span><span class="pun">-</span><span class="pln">i $log </span><span class="pun">&gt;</span><span class="pln"> plot</span><span class="pun">.</span><span class="pln">png</span></code></li></ol></pre>
<p>配置好参数之后,执行脚本 <code>./train.sh</code> 训练日志类似如下所示:</p>
<pre class="prettyprint linenums prettyprinted"><ol class="linenums"><li class="L0"><code><span class="pln">I0117 </span><span class="lit">12</span><span class="pun">:</span><span class="lit">52</span><span class="pun">:</span><span class="lit">29.628617</span><span class="pln"> </span><span class="lit">4538</span><span class="pln"> </span><span class="typ">TrainerInternal</span><span class="pun">.</span><span class="pln">cpp</span><span class="pun">:</span><span class="lit">165</span><span class="pun">]</span><span class="pln"> </span><span class="typ">Batch</span><span class="pun">=</span><span class="lit">100</span><span class="pln"> samples</span><span class="pun">=</span><span class="lit">12800</span><span class="pln"> </span><span class="typ">AvgCost</span><span class="pun">=</span><span class="lit">2.63996</span><span class="pln"> </span><span class="typ">CurrentCost</span><span class="pun">=</span><span class="lit">2.63996</span><span class="pln"> </span><span class="typ">Eval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.241172</span><span class="pln"> </span><span class="typ">CurrentEval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.241172</span><span class="pln"> </span></code></li><li class="L1"><code><span class="pun">.........</span></code></li><li class="L2"><code><span class="pln">I0117 </span><span class="lit">12</span><span class="pun">:</span><span class="lit">52</span><span class="pun">:</span><span class="lit">29.768741</span><span class="pln"> </span><span class="lit">4538</span><span class="pln"> </span><span class="typ">TrainerInternal</span><span class="pun">.</span><span class="pln">cpp</span><span class="pun">:</span><span class="lit">165</span><span class="pun">]</span><span class="pln"> </span><span class="typ">Batch</span><span class="pun">=</span><span class="lit">200</span><span class="pln"> samples</span><span class="pun">=</span><span class="lit">25600</span><span class="pln"> </span><span class="typ">AvgCost</span><span class="pun">=</span><span class="lit">1.74027</span><span class="pln"> </span><span class="typ">CurrentCost</span><span class="pun">=</span><span class="lit">0.840582</span><span class="pln"> </span><span class="typ">Eval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.185234</span><span class="pln"> </span><span class="typ">CurrentEval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.129297</span><span class="pln"> </span></code></li><li class="L3"><code><span class="pun">.........</span></code></li><li class="L4"><code><span class="pln">I0117 </span><span class="lit">12</span><span class="pun">:</span><span class="lit">52</span><span class="pun">:</span><span class="lit">29.916970</span><span class="pln"> </span><span class="lit">4538</span><span class="pln"> </span><span class="typ">TrainerInternal</span><span class="pun">.</span><span class="pln">cpp</span><span class="pun">:</span><span class="lit">165</span><span class="pun">]</span><span class="pln"> </span><span class="typ">Batch</span><span class="pun">=</span><span class="lit">300</span><span class="pln"> samples</span><span class="pun">=</span><span class="lit">38400</span><span class="pln"> </span><span class="typ">AvgCost</span><span class="pun">=</span><span class="lit">1.42119</span><span class="pln"> </span><span class="typ">CurrentCost</span><span class="pun">=</span><span class="lit">0.783026</span><span class="pln"> </span><span class="typ">Eval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.167786</span><span class="pln"> </span><span class="typ">CurrentEval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.132891</span><span class="pln"> </span></code></li><li class="L5"><code><span class="pun">.........</span></code></li><li class="L6"><code><span class="pln">I0117 </span><span class="lit">12</span><span class="pun">:</span><span class="lit">52</span><span class="pun">:</span><span class="lit">30.061213</span><span class="pln"> </span><span class="lit">4538</span><span class="pln"> </span><span class="typ">TrainerInternal</span><span class="pun">.</span><span class="pln">cpp</span><span class="pun">:</span><span class="lit">165</span><span class="pun">]</span><span class="pln"> </span><span class="typ">Batch</span><span class="pun">=</span><span class="lit">400</span><span class="pln"> samples</span><span class="pun">=</span><span class="lit">51200</span><span class="pln"> </span><span class="typ">AvgCost</span><span class="pun">=</span><span class="lit">1.23965</span><span class="pln"> </span><span class="typ">CurrentCost</span><span class="pun">=</span><span class="lit">0.695054</span><span class="pln"> </span><span class="typ">Eval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.160039</span><span class="pln"> </span><span class="typ">CurrentEval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.136797</span><span class="pln"> </span></code></li><li class="L7"><code><span class="pun">......</span><span class="pln">I0117 </span><span class="lit">12</span><span class="pun">:</span><span class="lit">52</span><span class="pun">:</span><span class="lit">30.223270</span><span class="pln"> </span><span class="lit">4538</span><span class="pln"> </span><span class="typ">TrainerInternal</span><span class="pun">.</span><span class="pln">cpp</span><span class="pun">:</span><span class="lit">181</span><span class="pun">]</span><span class="pln"> </span><span class="typ">Pass</span><span class="pun">=</span><span class="lit">0</span><span class="pln"> </span><span class="typ">Batch</span><span class="pun">=</span><span class="lit">469</span><span class="pln"> samples</span><span class="pun">=</span><span class="lit">60000</span><span class="pln"> </span><span class="typ">AvgCost</span><span class="pun">=</span><span class="lit">1.1628</span><span class="pln"> </span><span class="typ">Eval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.156233</span><span class="pln"> </span></code></li><li class="L8"><code><span class="pln">I0117 </span><span class="lit">12</span><span class="pun">:</span><span class="lit">52</span><span class="pun">:</span><span class="lit">30.366894</span><span class="pln"> </span><span class="lit">4538</span><span class="pln"> </span><span class="typ">Tester</span><span class="pun">.</span><span class="pln">cpp</span><span class="pun">:</span><span class="lit">109</span><span class="pun">]</span><span class="pln"> </span><span class="typ">Test</span><span class="pln"> samples</span><span class="pun">=</span><span class="lit">10000</span><span class="pln"> cost</span><span class="pun">=</span><span class="lit">0.50777</span><span class="pln"> </span><span class="typ">Eval</span><span class="pun">:</span><span class="pln"> classification_error_evaluator</span><span class="pun">=</span><span class="lit">0.0978</span><span class="pln"> </span></code></li></ol></pre></li>
<li><p>用脚本 <code>plot_cost.py</code> 可以画出训练过程中的误差变化曲线:</p>
<pre class="prettyprint linenums prettyprinted"><ol class="linenums"><li class="L0"><code class="language-bash"><span class="pln">python plot_cost</span><span class="pun">.</span><span class="pln">py softmax_train</span><span class="pun">.</span><span class="pln">log </span></code></li></ol></pre></li>
<li><p>用脚本 <code>evaluate.py</code> 可以选出最佳训练的模型:</p>
<pre class="prettyprint linenums prettyprinted"><ol class="linenums"><li class="L0"><code class="language-bash"><span class="pln">python evaluate</span><span class="pun">.</span><span class="pln">py softmax_train</span><span class="pun">.</span><span class="pln">log</span></code></li></ol></pre></li>
</ol><div class="md-section-divider"></div><h3 data-anchor-id="5qjw" id="softmax回归的训练结果">softmax回归的训练结果</h3><p align="center" data-anchor-id="j38j">
<img src="https://raw.githubusercontent.com/PaddlePaddle/book/develop/recognize_digits/image/softmax_train_log.png" width="400"><br>
图7. softmax回归的误差曲线图<br>
</p><p data-anchor-id="3c5t">评估模型结果如下:</p><div class="md-section-divider"></div><pre class="prettyprint linenums prettyprinted" data-anchor-id="rwc2"><ol class="linenums"><li class="L0"><code class="language-text"><span class="typ">Best</span><span class="pln"> </span><span class="kwd">pass</span><span class="pln"> </span><span class="kwd">is</span><span class="pln"> </span><span class="lit">00013</span><span class="pun">,</span><span class="pln"> testing </span><span class="typ">Avgcost</span><span class="pln"> </span><span class="kwd">is</span><span class="pln"> </span><span class="lit">0.484447</span></code></li><li class="L1"><code class="language-text"><span class="typ">The</span><span class="pln"> classification accuracy </span><span class="kwd">is</span><span class="pln"> </span><span class="lit">90.01</span><span class="pun">%</span></code></li></ol></pre><div class="md-section-divider"></div><p align="center" data-anchor-id="p1pu">
</p><p data-anchor-id="2jhq">评估模型结果如下:</p><div class="md-section-divider"></div><pre class="prettyprint linenums prettyprinted" data-anchor-id="45oz"><ol class="linenums"><li class="L0"><code class="language-text"><span class="typ">Best</span><span class="pln"> </span><span class="kwd">pass</span><span class="pln"> </span><span class="kwd">is</span><span class="pln"> </span><span class="lit">00013</span><span class="pun">,</span><span class="pln"> testing </span><span class="typ">Avgcost</span><span class="pln"> </span><span class="kwd">is</span><span class="pln"> </span><span class="lit">0.484447</span></code></li><li class="L1"><code class="language-text"><span class="typ">The</span><span class="pln"> classification accuracy </span><span class="kwd">is</span><span class="pln"> </span><span class="lit">90.01</span><span class="pun">%</span></code></li></ol></pre><div class="md-section-divider"></div><p align="center" data-anchor-id="p1pu">
<img src="https://raw.githubusercontent.com/PaddlePaddle/book/develop/recognize_digits/image/mlp_train_log.png" width="400"><br>
图8. 多层感知器的误差曲线图
</p><p data-anchor-id="eza7">评估模型结果如下:</p><div class="md-section-divider"></div><pre class="prettyprint linenums prettyprinted" data-anchor-id="hk46"><ol class="linenums"><li class="L0"><code class="language-text"><span class="typ">Best</span><span class="pln"> </span><span class="kwd">pass</span><span class="pln"> </span><span class="kwd">is</span><span class="pln"> </span><span class="lit">00085</span><span class="pun">,</span><span class="pln"> testing </span><span class="typ">Avgcost</span><span class="pln"> </span><span class="kwd">is</span><span class="pln"> </span><span class="lit">0.164746</span></code></li><li class="L1"><code class="language-text"><span class="typ">The</span><span class="pln"> classification accuracy </span><span class="kwd">is</span><span class="pln"> </span><span class="lit">94.95</span><span class="pun">%</span></code></li></ol></pre><div class="md-section-divider"></div><p align="center" data-anchor-id="4hu1">
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册