diff --git a/tutorials/notebook/linear_regression.ipynb b/tutorials/notebook/linear_regression.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..026e1dfd4ddd698bba9fd8c14900ea4027738487 --- /dev/null +++ b/tutorials/notebook/linear_regression.ipynb @@ -0,0 +1,711 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##
使用MindSpore实现简单线性函数拟合" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 概述" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "回归问题算法通常是利用一系列属性来预测一个值,预测的值是连续的。例如给出一套房子的一些特征数据,如面积、卧室数等等来预测房价,利用最近一周的气温变化和卫星云图来预测未来的气温情况等。如果一套房子实际价格为500万元,通过回归分析的预测值为499万元,则认为这是一个比较好的回归分析。在机器学习问题中,常见的回归分析有线性回归、多项式回归、逻辑回归等。本例子介绍线性回归算法,并通过MindSpore进行线性回归AI训练体验。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "主要流程如下:\n", + "\n", + "1. 生成数据集\n", + "2. 定义前向传播网络\n", + "3. 定义反向传播网络\n", + "4. 定义线性拟合过程的可视化函数\n", + "5. 执行训练" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 生成数据集" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 定义数据集生成函数\n", + "\n", + "`get_data`用于生成训练数据集和测试数据集。由于拟合的是线性数据,假定要拟合的目标函数为:$y=2x+3$,那么我们需要的训练数据集应随机分布于函数周边,这里采用了`y=2x+3+noise`的方式生成,其中`noise`为遵循标准正态分布规律的随机数值。" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import mindspore as ms\n", + "from mindspore import Tensor\n", + "from mindspore import context\n", + "from mindspore.train import Model\n", + "\n", + "context.set_context(mode=context.PYNATIVE_MODE, device_target=\"GPU\")\n", + " \n", + "def get_data(num,w=2.0, b=3.0):\n", + " np_x = np.ones([num, 1])\n", + " np_y = np.ones([num, 1])\n", + " for i in range(num):\n", + " x = np.random.uniform(-10.0, 10.0)\n", + " np_x[i] = x\n", + " noise = np.random.normal(0, 1)\n", + " y = x * w + b + noise\n", + " np_y[i] = y\n", + " return Tensor(np_x,ms.float32), Tensor(np_y,ms.float32)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "对于数据生成函数我们将有以下两个作用。\n", + "\n", + "1. 生成训练数据,对模型函数进行训练。\n", + "2. 生成验证数据,在训练结束后,对模型函数进行精度验证。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 生成测试数据" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "使用数据生成函数`get_data`随机生成50组验证数据,并可视化展示。" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAEICAYAAAC6fYRZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAX5ElEQVR4nO3dfZBcVZ3G8edJEFaUVTSTEEhiwALWaNWCdkVejARBBdY1YIkGt3YpdY2olC+lVYtSpZTW1uouirWlC4aSFbeUlxWQLIK8SSBbvMgEA4QNLAFxGRInk6ASXwprMr/94942neZ2pnv63u7bt7+fqq7pvvdOn1N3ep45c+455zoiBACopln9rgAAoDiEPABUGCEPABVGyANAhRHyAFBhhDwAVBghDwAVRsgDKdtrbf99h9+z2HbY3qeoegHdIOQxkGw/ZfsPtn/b8PhGv+u1N7aX2x7rdz0wXGh9YJD9dUTc1u9KAGVGSx6VYXs/27+2/bqGbSNpi3+u7QNt32B7wvav0ucLOixjtu0LbW+3/aSkv2ra/37bm2zvtP2k7Q+n218i6SZJBzf853Gw7aW270nrvdX2N2zvm8PpACQR8qiQiHhe0rWSzmrY/B5Jd0bENiWf93+X9CpJiyT9QVKnXTwfkvQOSUdLqkl6d9P+ben+P5f0fkkX2X59RPxO0qmStkTES9PHFkm7JH1K0hxJx0o6SdJHO6wT0BIhj0H2w7QFXH98SNL3tWfIvy/dpojYERHXRMTvI2KnpH+UdEKHZb5H0tcj4umIeFbSPzXujIgfRcQTkbhT0i2SlrV6s4hYHxH3RsRkRDwl6VszqBPQEn3yGGSnN/fJ254l6cW23yjpl5KOknRdum9/SRdJOkXSgem3HGB7dkTsarPMgyU93fD6F03lnyrpC5KOUNKI2l/Sw63ezPYRkr6m5L+C/ZX8Tq5vsy7AtGjJo1IiYkrS1Upa8++TdEPaapekT0s6UtIbI+LPJb053e4OitgqaWHD60X1J7b3k3SNpAslzYuIl0u6seH9s9b1vljSo5IOT+v0uQ7rA+wVIY8q+r6k90r6m/R53QFK+uF/bfsVSlrcnbpa0sdtL7B9oKTzGvbtK2k/SROSJtNW/dsa9o9LeqXtlzXV6TlJv7X9F5I+MoM6AS0R8hhk/9U0Tv46SYqI+yT9TknXyk0Nx39d0oslbZd0r6Qfz6DMSyXdLOlBSQ8oudCrtNydkj6u5A/Br5T8J7GmYf+jkq6Q9GR6DeFgSZ9Jj9uZvvdVM6gT0JK5MxQAVBcteQCoMEIeaGL7kqZuoPrjkn7XDegU3TUAUGGlGic/Z86cWLx4cb+rAQADZf369dsjYiRrX6lCfvHixRodHe13NQBgoNj+Rat99MkDQIUR8gBQYYQ8AFQYIQ8AFUbIA0CFEfIAUGGEPAD029SUND4uFTA5lZAHgH6ampJOPFFasEBavjx5nSNCHgD6aWJCuvtuaXIy+ToxkevbE/IA0E9z50rHHSfts0/yde7cXN++VMsaAMDQsaU77kha8HPnJq9zRMgDQL/NmiXNm1fMW3f7BrYX2r7D9ibbj9j+RLr9FbZvtf14+vXA7qsLAOhEHn3yk5I+HRGvkXSMpI/ZXqLkBse3R8Thkm7Xnjc8BgD0QNchHxFbI+KB9PlOSZskHSJphaTL08Mul3R6t2UBADqT6+ga24slHS3pPknzImKrlPwhkJR5ydj2Ktujtkcnch46BADDLreQt/1SSddI+mREPNfu90XE6oioRURtZCTzxiYAgBnKJeRtv0hJwH8vIq5NN4/bnp/uny9pWx5lAQDal8foGkv6tqRNEfG1hl1rJJ2dPj9b0vXdlgUA6Ewe4+SPl/S3kh62vSHd9jlJX5Z0te0PSvo/SWfmUBYAoANdh3xE/LekVlO0Tur2/QEAM8faNQBQYYQ8ANS1u657geu/542QBwApCe7ly6VDDpFOOKH1uu4Fr/+eN0IeAKSkZb5unbRrV/J1fDz7uILXf88bIQ8AUrLEb32Z38bnzQpe/z1vhDwASMlSv8uWJeG9bFnrpX/r67+PjUlr1+a+/nveWE8eAKTObt5R4PrveSPkAaBugMK7XXTXAECzARoiOR1CHgAaDdgQyekQ8gDQaMCGSE6HkAeARgM2RHI6XHgFMBymptobOdPJKJsBQEseQPV12s9eH2Uz4AEvEfIABkkno14aj61YP3snCHkAg6GT1njzsXPmVKqfvRP0yQMYDFmt8VYTl5qP3b69Uv3sncjrRt6X2d5me2PDtgtsP2N7Q/o4LY+yAAypTka9zJkj1WrS7Nm7j61QP3sn8mrJf0fSNyR9t2n7RRFxYU5lABhm7Y56mZqS3vIWaXRUWrpU+slPhi7YG+XSko+IuyQ9m8d7AUBL7bTGG7tq7r8/6aoZYkVfeD3X9kNpd86BWQfYXmV71PboxBBd8QZQkIpNZupWkSF/saRXSzpK0lZJX806KCJWR0QtImojIyMFVgfAUBiw9d6LVljIR8R4ROyKiClJl0paWlRZALCHIb3ImqWwkLc9v+HlGZI2tjoWAFCMXEbX2L5C0nJJc2yPSfqCpOW2j5IUkp6S9OE8ygIAtC+XkI+IszI2fzuP9wYAzBzLGgBAhRHyAFBhhDwAVBghDwAVRsgDQIUR8gDKoZMbgqBthDyA/uv09nxoGyEPoP+G+PZ8RSPkAfQfK0cWhtv/Aei/dm8Igo4R8gDKob5yJHJFdw0AVBghDwAVRsgDQIUR8gBQYYQ8gGzMQK0EQh7ACzEDtTJyCXnbl9neZntjw7ZX2L7V9uPp1wPzKAtADzADtTLyasl/R9IpTdvOk3R7RBwu6fb0NYBBwAzUysgl5CPiLknPNm1eIeny9Pnlkk7PoywAOWvse68/l5IZqGNj0tq1zEAdYEX2yc+LiK2SlH7NbArYXmV71PboBP8SAr3V2Pd+wgl79sNLyQxUAn6g9f3Ca0SsjohaRNRGRkb6XR1guDT3vdMPXzlFhvy47fmSlH7dVmBZAGaiue+9k354hlgOhCIXKFsj6WxJX06/Xl9gWQBmonn1x4j2VoKsd/PcfXfyB+GOO5IFxlA6eQ2hvELSPZKOtD1m+4NKwv2tth+X9Nb0NYAyaGyF11d/tPd8vjcMsRwYubTkI+KsFrtOyuP9AeQoj1Z4vZun/h4MsSwt1pMHqmhqqnW3S1YrvNN13LnJx8CgEw2omlZLEtS7aEZG8pno1G7XDvqKljxQNVkt9ZGRPbtobr9d2rGDVvgQoCUPVE3WkgTNwb9jB63wIUFLHqiaen/5+PjuEM/jQune+vlRWrTkgapauVJauDDpl4/obi0alh4eWIQ8UEXj49K6dUn3zLp1yetuLpQyLn5gEfJAFdm7lxuI6L57haWHBxYhD1RF4yzWefOkZcuk2bOTr52Og29W7+dn6eGBQ8gDVdDcZx6RhPEzz0h33plPKDMufiAR8kAVZPWZE8oQIQ9UA33maIFx8kAVsJYMWqAlD5RZJzfmoHsGGQh5oKyYgIQcEPJAWTEBCTkg5IGy4mIqclD4hVfbT0naKWmXpMmIqBVdJjBwshb/4mIqctCrlvyJEXEUAQ9k2FvfOxdT0SW6a4A8dTIapo6+dxSoFyEfkm6xvd72qh6UB/THTEfD0PeOAvViMtTxEbHF9lxJt9p+NCLuqu9Mg3+VJC1atKgH1QEKMtMbZNP3jgIV3pKPiC3p122SrpO0tGn/6oioRURtZGSk6OoAxemmRU7fOwpSaEve9kskzYqInenzt0n6YpFlAn1DixwlVHR3zTxJ1zn5sO8j6fsR8eOCywT6p94iB0qi0JCPiCcl/WWRZQAAWmMIJQBUGCEPABVGyANAhRHyAFBhhDwAVBghDwAVRsgDQIUR8gBQYYQ8AFQYIQ8AFUbIA0CFEfIYPjO5exMwoAh5DI+pKWnr1uSuTZ3evQkYUL24MxTQf/Vb89Xv3CR1dvcmYEAR8hgOjbfms5N137mfKoYA3TUYDo235lu2TBobk9au5e5NqDxa8hgO3JoPQ4qQx/Dg1nwYQoV319g+xfZjtjfbPq/o8gAAuxUa8rZnS/qmpFMlLZF0lu0lRZYJANit6Jb8UkmbI+LJiPijpCslrSi4TABAquiQP0TS0w2vx9Jtf2J7le1R26MTExMFVwcAhkvRIZ81hGGPueQRsToiahFRGxkZKbg6GAgsOwDkpuiQH5O0sOH1AklbCi4Tg6w+M5VlB4BcFB3y90s63PahtveVtFLSmoLLxCBrnJlaX3YAwIwVGvIRMSnpXEk3S9ok6eqIeKTIMjHgGmem1pcdoPsGmLHCJ0NFxI2Sbiy6HFRE88zUiN0Lix13XLJvFqtxAO3itwXlU5+ZatN9A3SJkEe5ZXXfAGgba9eg3FhYDOgKIY/yY2ExYMborgGACiPkAaDCCHnkh/HsQOkQ8shH43IEJ5wgbd1K2AMlQMgjH43j2detkxYtYu0ZoAQIeeSjPp599uxkmCOTl4BSIOSRj/p49rExadkyJi8BJcE4eeRn1izpoIOYvASUCC155K9x7ZlOMDoHyB0hj3LgZiFAIQh5JJpb0Xm2qtt5L1abBApByOOFrejJyfxa1e220FltEiiEo0T9n7VaLUZHR/tdjeEzPp6E8ORkErI/+5l09NG7X4+NzXyBsOb33tt7TU1xwRaYAdvrI6KWta+wlrztC2w/Y3tD+jitqLLQpeZW9JIl+bWqO2mhz/SCLYCWih5CeVFEXFhwGehW1prteQ2DZD14oK8YJ49E85rtea7hznrwQN8UfeH1XNsP2b7M9oFZB9heZXvU9ugEIyoGC+PagdLrKuRt32Z7Y8ZjhaSLJb1a0lGStkr6atZ7RMTqiKhFRG1kZKSb6qCXGNcODISuumsi4uR2jrN9qaQbuikLJZM1rp0uGaB0ihxdM7/h5RmSNhZVFrqQ1eXSTjcM49qBgVBkn/w/237Y9kOSTpT0qQLLwkxkdbm02w3TuOrk2rWMmgFKislQwyxropLU/uQlAKXQl8lQGABZXS50wwCVwjj5YdZqohKTl4DKoCVfVe2OYc9aSoDlBYDKIOSriDHsAFKEfBXtbW12ZqkCQ4WQr6JWF09p4QNDhwuvVdTqgiqzVIGhQ0u+qrIunjI8Ehg6tOSHCWu7A0OHkB82rO0ODBW6awCgwgh5AKgwQh4AKoyQB4AKI+T7idmnAApGyPcLs08B9AAh3y97W18GAHLSVcjbPtP2I7anbNea9n3W9mbbj9l+e3fVrCBmnwLogW4nQ22U9C5J32rcaHuJpJWSXivpYEm32T4iInZ1WV51MPsUQA901ZKPiE0R8VjGrhWSroyI5yPi55I2S1raTVmVxM05ABSsqD75QyQ93fB6LN32ArZX2R61PTpBvzQA5Gra7hrbt0k6KGPX+RFxfatvy9iWOU4wIlZLWi1JtVqNsYQAkKNpQz4iTp7B+45JWtjweoGkLTN4HwBAF4rqrlkjaaXt/WwfKulwST8tqKzqYHIUgJx1O4TyDNtjko6V9CPbN0tSRDwi6WpJ/yPpx5I+xsiaaTA5CkABHCVqNdZqtRgdHe13NfpjfDwJ+MnJZOz82BjrvgNoi+31EVHL2seM17JgchSAAnBnqLJgchSAAhDyZcKt+QDkjO6aojFiBkAfEfJ5ag50RswA6DNCPi9Zgc5ywgD6jJDPS1agM2IGQJ8R8nnJCvT6iJmxMWntWkbMAOg5RtfkpdUQSEbMAOgjQj5PBDqAkqG7plMMiQQwQAj5TkxOSm96E0MiAQwMumvaNTUlLVsm3Xtv8vruu5MW/axZLEMAoLRoybdrYkK6//7dr2s16b3vpVUPoNQI+XbNnSsdf7w0e7Z0zDHStddK99zDRCcApUZ3Tbuah0hKyXj4u+9mohOA0iLkO9E8RJKlgQGUXLe3/zvT9iO2p2zXGrYvtv0H2xvSxyXdVzUHeQ9/rIc+AQ+gpLrtk98o6V2S7srY90REHJU+zumynO6xIiSAIdRVyEfEpoh4LK/K5K6x5c6KkACGUJGjaw61/TPbd9pe1uog26tsj9oencgzeJtb7nPmsCIkgKEz7YVX27dJOihj1/kRcX2Lb9sqaVFE7LD9Bkk/tP3aiHiu+cCIWC1ptSTVarWZd5bX12+vXwRtbrlv386FUgBDZ9qWfEScHBGvy3i0CnhFxPMRsSN9vl7SE5KOyK/aTbL627OW/uVCKYAhU8gQStsjkp6NiF22D5N0uKQniyhLUnZ/+7x5tNwBDL1uh1CeYXtM0rGSfmT75nTXmyU9ZPtBST+QdE5EPNtdVfei1R2YaLkDGHKOEi2ZW6vVYnR0dGbf3NwnDwBDwvb6iKhl7avO2jX1VnsE670DQKo6IS8x4QkAmlQr5JnwBAB7qFbIt7oACwBDqlqrUDYvB8wFWABDrlohL71wOWAAGGLV6q4BAOyBkAeACiPkAaDCCHkAqDBCHgAqjJAHgAor1QJltick/WKaw+ZI2t6D6nSDOuaj7HUse/0k6piXstfxVRExkrWjVCHfDtujrVZbKwvqmI+y17Hs9ZOoY14GoY6t0F0DABVGyANAhQ1iyK/udwXaQB3zUfY6lr1+EnXMyyDUMdPA9ckDANo3iC15AECbCHkAqLBShrztM20/YnvKdq1p32dtb7b9mO23t/j+Q23fZ/tx21fZ3rfg+l5le0P6eMr2hhbHPWX74fS4Gd6xfMZ1vMD2Mw31PK3Fcaek53az7fN6WL9/sf2o7YdsX2f75S2O6/k5nO6c2N4v/QxsTj93i3tRr4byF9q+w/am9PfmExnHLLf9m4af/+d7Wce0Dnv92Tnxr+l5fMj263tcvyMbzs8G28/Z/mTTMX0/jx2LiNI9JL1G0pGS1kqqNWxfIulBSftJOlTSE5JmZ3z/1ZJWps8vkfSRHtb9q5I+32LfU5Lm9OmcXiDpM9McMzs9p4dJ2jc910t6VL+3Sdonff4VSV8pwzls55xI+qikS9LnKyVd1eOf7XxJr0+fHyDpfzPquFzSDf347LX7s5N0mqSbJFnSMZLu62NdZ0v6pZJJRqU6j50+StmSj4hNEfFYxq4Vkq6MiOcj4ueSNkta2niAbUt6i6QfpJsul3R6kfVtKvs9kq7oRXkFWCppc0Q8GRF/lHSlknNeuIi4JSIm05f3SlrQi3Lb0M45WaHkcyYln7uT0s9CT0TE1oh4IH2+U9ImSYf0qvwcrZD03UjcK+nltuf3qS4nSXoiIqabgV96pQz5vThE0tMNr8f0wg/zKyX9uiEwso4pyjJJ4xHxeIv9IekW2+ttr+pRnRqdm/4bfJntAzP2t3N+e+EDSlp0WXp9Dts5J386Jv3c/UbJ57Dn0q6ioyXdl7H7WNsP2r7J9mt7WrHEdD+7snz+pOQ/slaNtX6fx4707fZ/tm+TdFDGrvMj4vpW35axrXkMaDvHdKzN+p6lvbfij4+ILbbnSrrV9qMRcVe3dWunjpIulvQlJefiS0q6lT7Q/BYZ35vbGNt2zqHt8yVNSvpei7cp9Bxm6NtnrlO2XyrpGkmfjIjnmnY/oKTr4bfp9ZgfSjq8x1Wc7mdXlvO4r6R3Svpsxu4ynMeO9C3kI+LkGXzbmKSFDa8XSNrSdMx2Jf/m7ZO2qrKO6dh09bW9j6R3SXrDXt5jS/p1m+3rlHQF5BZQ7Z5T25dKuiFjVzvnd8baOIdnS3qHpJMi7QDNeI9Cz2GGds5J/Zix9HPwMknPFlinF7D9IiUB/72IuLZ5f2PoR8SNtv/N9pyI6NmiW2387Ar9/HXgVEkPRMR4844ynMdODVp3zRpJK9PRDIcq+Qv608YD0nC4Q9K7001nS2r1n0GeTpb0aESMZe20/RLbB9SfK7nQuLEH9aqX39i3eUaLsu+XdLiT0Un7KvmXdU2P6neKpH+Q9M6I+H2LY/pxDts5J2uUfM6k5HP3k1Z/pIqQ9v9/W9KmiPhai2MOql8nsL1Uye/+jh7WsZ2f3RpJf5eOsjlG0m8iYmuv6tig5X/k/T6PM9LvK79ZDyUhNCbpeUnjkm5u2He+ktEOj0k6tWH7jZIOTp8fpiT8N0v6T0n79aDO35F0TtO2gyXd2FCnB9PHI0q6KHp5Tv9D0sOSHlLyyzS/uY7p69OUjM54opd1TH9WT0vakD4uaa5fv85h1jmR9EUlf5Ak6c/Sz9nm9HN3WI9/tm9S0q3xUMP5O03SOfXPpKRz03P2oJIL28f1uI6ZP7umOlrSN9Pz/LAaRtb1sJ77KwntlzVsK815nMmDZQ0AoMIGrbsGANABQh4AKoyQB4AKI+QBoMIIeQCoMEIeACqMkAeACvt/V7N/5YFcCdgAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "eval_x, eval_label = get_data(50)\n", + "x1, y1 = eval_x.asnumpy(), eval_label.asnumpy()\n", + "plt.scatter(x1, y1, color=\"red\", s=5)\n", + "plt.title(\"Eval_data\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 定义前向传播网络" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 初始化网络模型" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "使用`nn.Dense`定义了网络模型,即为线性模型,\n", + "\n", + "$$y=wx+b\\tag{1}$$\n", + "\n", + "其中,权重值$w$对应`weight`,$b$对应`bias`,并将其打印出来。" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "weight: -0.00034249047 bias: -0.019308656\n" + ] + } + ], + "source": [ + "from mindspore.common.initializer import TruncatedNormal\n", + "from mindspore import nn\n", + "\n", + "net = nn.Dense(1,1,TruncatedNormal(0.02),TruncatedNormal(0.02))\n", + "print(\"weight:\", net.weight.default_input[0][0], \"bias:\", net.bias.default_input[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 查看初始化的网络模型" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "我们将验证数据集和初始化的模型函数可视化。" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAEICAYAAAC6fYRZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAaVUlEQVR4nO3df7QcZZ3n8feHIMiIC8F7EwJJDDj4I7pnUdsoYCACg8g4A3hGN8weZdXdLI7sGXdnzi4uZx1Gd+aoM4575owjG0eEmREVUSSDKARIhFF+3SCEhMgQFOWSeHNDVBh12b253/2jqjeVtvre/lHVP6o/r3P6dHdVdT1PV/f99nO/9TxPKSIwM7NqOqTfFTAzs/I4yJuZVZiDvJlZhTnIm5lVmIO8mVmFOcibmVWYg7wNFUmbJf27FrddI2my7DqVQdITks7udz1s+DnIWynSIPVLSf+cuf1Vv+vVjKR/K+kf+12Pskm6WtL/6Hc9rHcO7XcFrNJ+KyJu63clzEaZW/LWU5IOl/RTSa/KLBtPW/2LJC2UdJOkaUk/SR8vbXHfR6Qt1Z9IegR4XcP6yyQ9LulZSY9IujBd/grgSuCU9D+On6bLf1PSdyU9I+lJSVfMUfac9U7TTB+R9O20/FsljWXWv1PSDyU9Lenyed7n1ZI+Jenr6b7ulfSSzPqXS9ooaZ+kRyW9I12+Dvg3wH9J3+c/tHJcbbg5yFtPRcRzwFeBizKL3wF8KyL2kHwnPwe8GFgO/BJoNc3zR8BL0tubgYsb1j8OrAaOAv4Y+HtJSyJiB3AJcHdEHBkRR6fb/xx4F3A08JvA+yRd0KTsVur9u8C7gUXAYcAfAkhaCXwaeCdwHPAiYL4ftovS97AQ2An8SbqvFwAbgWvTci4C/lrSKyNiPfB54OPp+/ytecqwCnCQtzJ9LW2112//Pl1+LQcH+d9NlxERT0fEVyLiFxHxLEnwOqPF8t4B/ElE7IuIJ4G/zK6MiC9HxK6ImI2ILwGPAaua7SwiNkfEw+n2W4EvNKtLi/X+XET8U0T8ErgOODld/jvATRFxZ/oj+N+B2Xne61cj4r6ImCEJ3PV9vRV4IiI+FxEzEfEA8JW0DBtBzslbmS5okpO/AzhC0uuBH5MEqBsAJP0a8EngXJJWKsALJS2IiP3zlHcc8GTm+Q+zKyW9C/jPwIp00ZHAGE2k9fso8CqSlvfhwJebbNtKvX+ceckv0vJ/pd4R8XNJTzer1zz7ejHw+nrKKXUo8Hfz7M8qyi1567mImCVpyV5E0oq/KW39AvwB8DLg9RHxL4DT0+VqYde7gWWZ58vrDyS9GPgMcCnwojQlsy2z37zpWK8FNgDLIuIokrx9s3oUVu/0B+NFLbwuz5Mkqa+jM7cjI+J96XpPOztiHOStX64F/jXJicBrM8tfSJLP/qmkY0jy7K26DvhgehJ0KfAfM+teQBLgpgEkvZukhV43BSyVdFhDXfZFxP+WtIrkB6mZbup9PfBWSW9My/8wnf9t3gS8ND2R+7z09rr05DIk7/PEDvdtQ8hB3sr0Dw395G+or4iIe0lObB4HfCPzmv8JHAHsBe4BvtlGeX9MkqL5AXArmRRFRDwCfAK4myTQ/Uvg25nX3gFsB34saW+67PeAD0t6FvgQyY9IMx3XOyK2A+8n+bHbDfwE6GgQV/of0TnAWmAXSVrnYySpJoDPAivTcyRf66QMGy7yRUPMzKrLLXkzswpzkDczqzAHeTOzCnOQNzOrsIEaDDU2NhYrVqzodzXMzIbKli1b9kbEeN66gQryK1asYGJiot/VMDMbKpJ+2Gyd0zVmZhXmIG9mVmEO8mZmFeYgb2ZWYQ7yZmYV5iBvZlZhDvJmZkWZnYWpKRigiR8d5M3MijA7C296EyxdCmvWJM8HgIO8mVkRpqfhO9+BmZnkfnq63zUCHOTNzIqxaBGceiocemhyv2hRv2sEDNi0BmZmQ0uCTZuSFvyiRcnzAeAgb2ZWlEMOgcWL+12Lg3SdrpG0TNImSTskbZf0++nyYyRtlPRYer+w++qamVk7isjJzwB/EBGvAN4AvF/SSuAy4PaIOAm4PX1uZmY91HWQj4jdEfFA+vhZYAdwPHA+cE262TXABd2WZWZm7Sm0d42kFcCrgXuBxRGxG5IfAiD3VLOkdZImJE1MD0iXIzOzqigsyEs6EvgK8IGIeKbV10XE+oioRURtfDz3wiZmZtahQoK8pOeRBPjPR8RX08VTkpak65cAe4ooy8zMWldE7xoBnwV2RMRfZFZtAC5OH18M3NhtWWZm1p4i+smfBrwTeFjSg+my/wZ8FLhO0nuBHwFvL6AsMzNrQ9dBPiL+EWg2tOusbvdvZmad89w1ZmYV5iBvZtXWyhzvAzgPfFEc5M2sumZm4LTT4Pjjm8/xPqDzwBfFQd7Mqml2Fk4/He65B/bvh29/O3+O9wGdB74oDvJmVk3T03D//Qeev+51+XO8D+g88EVxkDezasoG71NOSVryeXO81+eBn5yEzZsHZh74ong+eTOrpnYu4jGA88AXxUHezKqrwsG7VU7XmNloqHA3ybk4yJtZ9VW8m+RcHOTNrPoq3k1yLg7yZlZ9Fe8mORefeDWz4TY7O38PmnZ62lSMW/JmNrzaybXXe9qMUIAHB3kzGwSt9nxp3G6Ec+2tcpA3s/5qtTWet90I59pbpRigPqO1Wi0mJib6XQ0z66WpqSRwz8wkwXpyMn8AU7PtWsnJV5ykLRFRy1tX1IW8r5K0R9K2zLIrJD0l6cH0dl4RZZlZxbTaGh8bg1oNFiw4eLsRzbW3qqjeNVcDfwX8bcPyT0bEnxdUhplVUSs9X2Zn4cwzYWICVq2CO+5wUG9RIS35iLgT2FfEvsxsBM3XGs+eYL3/fti7t7f1G2Jln3i9VNLWNJ2zMG8DSeskTUiamPaZcTPL4xOsHSszyH8aeAlwMrAb+ETeRhGxPiJqEVEbHx8vsTpmNrQqPud7mUoL8hExFRH7I2IW+AywqqyyzGwE+ARrR0oL8pKWZJ5eCGxrtq2ZmZWjkN41kr4ArAHGJE0CfwSskXQyEMATwH8ooiwzM2tdIUE+Ii7KWfzZIvZtZmad87QGZmYV5iBvZlZhDvJmZhXmIG9mVmEO8mZmFeYgb2bla/WiIFY4B3kzK1c7l+izwjnIm1m5fIm+vnKQN7NyeQbJvirqoiFmZvlauSiIlcZB3szKV59B0nrO6RozswpzkDczqzAHeTOzCnOQNzOrMAd5s1HlUagjwUHebBR5FOrIKCTIS7pK0h5J2zLLjpG0UdJj6f3CIsoyswJ4FOrIKKolfzVwbsOyy4DbI+Ik4Pb0uZkNAo9CHRmFBPmIuBPY17D4fOCa9PE1wAVFlGVmHcjm32dnYc8euOMOmJyEzZs9CrXCyszJL46I3QDpfW5TQdI6SROSJqb9L6NZ8bL59zPOOPD4zDNhfNwBvuL6fuI1ItZHRC0iauPj4/2ujln1NObfnYsfKWUG+SlJSwDS+z0llmVmzTTm31vNxbuLZSWUOUHZBuBi4KPp/Y0llmVmzTTOAhkx/4yQ9RTPd76T/Bhs2pRMMmZDp6gulF8A7gZeJmlS0ntJgvtvSHoM+I30uZn1SrYlXp8FUjr4cTPuYlkZhbTkI+KiJqvOKmL/Ztamblvi9RRP/fXuYjm0PJ+82TCbnc1PveS1xNuZz90X+qgMJ9nMhlWzqQlmZ5MUTbeDnVpJ69jAc5A3G1Z5rfV64F+2LAn0P/qRBzuNOAd5s2GVNzVBNvDffXfSGneAH2nOyZsNq3refGrqQCDv5oRps/y+DTW35M2G3dq1SXpmzZokRbNpU/tz0njq4cpykDcbZlNTcNddSXrmrruS552cMHW/+MpykDcbZtKBaQciOk+zeOrhynKQNxsmjfPJLF4Mq1fDggXJfTt94bPq+X1PPVw5DvJmwyIvby4lQfmpp+Bb3+ouOLtffCU5yJsNi2Z5cwdnm4ODvNmwcN7cOuB+8mbDwvPJWAfckjcbFK1cpMOpGWuTg7zZIPBgJCuJg7zZIPBgJCuJg7zZIPBJVStJ6SdeJT0BPAvsB2YiolZ2mWYDLW8iMJ9UtZL0qiX/pog42QHeRt5cuXefVLUSOF1j1o1WesRkOfduPdaLIB/ArZK2SFrXg/LMeqOTHjHOvVuP9WIw1GkRsUvSImCjpO9FxJ31lWngXwewfPnyHlTHrCCdXCzbuXfrsdJb8hGxK73fA9wArGpYvz4iahFRGx8fL7s6ZsXptFXu3Lv1UKkteUkvAA6JiGfTx+cAHy6zTLOecavchkDZ6ZrFwA1KvvyHAtdGxDdLLtOsd+qtcrMBVWqQj4jvA/+qzDLMzKw5d6E0M6swB3kzswpzkDczqzAHeTOzCnOQNzOrMAd5M7MKc5A3M6swB3kzswpzkDczqzAHeTOzCnOQNzOrMAd5Gw3tXsHJrCIc5K3aZmdh9+7kyk3tXMHJrCJ6cWUos/6oX56vfvUmaP0KTmYV4SBv1ZW9PJ+UzP3u66raiHG6xqore3m+1athchI2b/YVnGykuCVv1eXL85k5yFvF+fJ8NuJKT9dIOlfSo5J2Srqs7PLMzOyAUoO8pAXAp4C3ACuBiyStLLNMMzM7oOx0zSpgZ3pBbyR9ETgfeKTIQrZtg7Vri9xj5wYl7et6HMz1OJjrcbBBqMc558Cf/mnx+y07yB8PPJl5Pgm8PruBpHXAOoDly5d3VMjznw8vf3mHNSzQoAymdD0O5noczPU42KDU46ijytlv2UE+7/fxoEMaEeuB9QC1Wq2jw/3rvw7XX9/JK23gzM66N4xZgco+8ToJLMs8XwrsKrlMG1b1EaqefsCsMGUH+fuBkySdIOkwYC2woeQybVhlR6jWpx8ws66UGuQjYga4FLgF2AFcFxHbyyzThlh2hOqpp8LYmGeONOtS6YOhIuJm4Oayy7EKyI5QHRuDM89MWvSnnposP8SzcJi1y381NljqI1T37nXqxqwADvI2mBpTN5450qwjnrvGBpMnFzMrhIO8DS5PLmbWNadrzMwqzEHezKzCHOStOLOz7tduNmAc5K0Y2SkJzjgDdu92sDcbAA7yVozslAR33QXLl3v+GbMB4CBvxaj3a1+wIOnu6EFMZgPBQd6KUe/XPjkJq1d7EJPZgHA/eSvOIYfAscd6EJPZAHFL3opXH8TUToB3zxyzUjjIW//5YiFmpXGQtwMaW9NFta7n248vFmJWGgd5SzS2pmdmimldt9JK94yTZqVRDFAOtFarxcTERL+rMZqmppJAPDOTBNvvfhde/eoDzycnO5ssrHG/zfbjC3ibdUzSloio5a0rrSUv6QpJT0l6ML2dV1ZZVoDG1vTKlcW0rlttpXdystbM5lV2F8pPRsSfl1yGFSFv/vYiukJ6XnizvnI/eTugcf72ouZz97zwZn1T9onXSyVtlXSVpIV5G0haJ2lC0sS0e1UMF/dtNxt4XQV5SbdJ2pZzOx/4NPAS4GRgN/CJvH1ExPqIqEVEbXx8vJvqWC+5b7vZUOgqXRMRZ7eynaTPADd1U5YNmLy+7U7JmA2cMnvXLMk8vRDYVlZZ1qW8tMt8qRj3bTcbCmXm5D8u6WFJW4E3Af+pxLKsU3lpl1ZSMdlZJzdvdq8ZswHlwVCjLm+wErQ2gMnMBkJfBkPZkMhLuzgVY1YZ7ic/6poNVvIAJrNKcEu+6lrpy543pYCnGTCrBAf5KnNfdrOR5yBfZXPN0+7RqmYjwUG+ypqdQHUL32xk+MRrlTU7qerRqmYjwy35qss7geoukmYjwy35UeQ53s1GhoP8qPIc72YjwekaM7MKc5A3M6swB3kzswpzkDczqzAH+UHhEahmVgIH+UHgEahmVhIH+UEw1xwzZmZd6CrIS3q7pO2SZiXVGtZ9UNJOSY9KenN31aw4j0A1s5J0OxhqG/A24H9lF0paCawFXgkcB9wm6aURsb/L8qrJI1DNrCRdteQjYkdEPJqz6nzgixHxXET8ANgJrOqmrMrzRTrMrARl5eSPB57MPJ9Ml/0KSeskTUiamHYu2sysUPOmayTdBhybs+ryiLix2ctyluX2DYyI9cB6gFqt5v6DZmYFmjfIR8TZHex3EliWeb4U2NXBfszMrAtlpWs2AGslHS7pBOAk4L6SyqomD44yswJ024XyQkmTwCnA1yXdAhAR24HrgEeAbwLvd8+aNnhwlJkVRDFALcVarRYTExP9rkb/TU0lAX5mJuk7Pznpud/NrClJWyKilrfOI14HkQdHmVlBfGWoQeTBUWZWEAf5QeXL85lZAZyu6TX3mjGzHnKQL1NjQHevGTPrMQf5suQFdE8pbGY95iBflryA7l4zZtZjDvJlyQvo9V4zk5OwebN7zZhZ6dy7pizNukG614yZ9ZCDfJkc0M2sz5yuKYK7RZrZgHKQ79bMDLzxje4WaWYDyUG+G7OzsHo13H33gV40jzziFr2ZDQwH+W5MT8P99x94fsQRcPLJbtGb2cBwkO/GokVw2mmwYAG89rXw85/D/v0e6GRmA8NBvhv1bpJPPQX33Zfk5j3QycwGiLtQdivbTdLTA5vZgOn28n9vl7Rd0qykWmb5Ckm/lPRgeruy+6qWqKgukPWA7wBvZgOi23TNNuBtwJ056x6PiJPT2yVdllMezwxpZhXWVZCPiB0R8WhRlemZbMvdM0OaWYWVeeL1BEnflfQtSaubbSRpnaQJSRPTvQiwjS33sTHPDGlmlTXviVdJtwHH5qy6PCJubPKy3cDyiHha0muBr0l6ZUQ807hhRKwH1gPUarXiRxHV53GvnwxtbLnv3esTpmZWWfO25CPi7Ih4Vc6tWYAnIp6LiKfTx1uAx4GXFlftFuXl2/OmAPYJUzOrqFK6UEoaB/ZFxH5JJwInAd8vo6w55eXbFy92y93MRka3XSgvlDQJnAJ8XdIt6arTga2SHgKuBy6JiH3dVbUDza7E5Ja7mY0IxQBNplWr1WJiYqLYnTbm5M3MKkbSloio5a2r/rQG9VZ7hOd8N7ORU/0gDx7wZGYjazSCvAc8mdmIGo0g3+wErJlZxY3GLJT1KYF9AtbMRsxoBHk4eEpgM7MRMRrpGjOzEeUgb2ZWYQ7yZmYV5iBvZlZhDvJmZhXmIG9mVmEDNUGZpGngh13sYgzYW1B1iuR6tcf1at+g1s31ak+n9XpxRIznrRioIN8tSRPNZmLrJ9erPa5X+wa1bq5Xe8qol9M1ZmYV5iBvZlZhVQvy6/tdgSZcr/a4Xu0b1Lq5Xu0pvF6VysmbmdnBqtaSNzOzDAd5M7MKG6ogL+ntkrZLmpVUa1j3QUk7JT0q6c1NXn+CpHslPSbpS5IOK6meX5L0YHp7QtKDTbZ7QtLD6XYFX8E8t7wrJD2Vqdt5TbY7Nz2OOyVd1oN6/Zmk70naKukGSUc32a4nx2u+9y/p8PQz3pl+n1aUVZdMmcskbZK0I/0b+P2cbdZI+lnm8/1Q2fXKlD3nZ6PEX6bHbKuk1/SgTi/LHIsHJT0j6QMN2/TkmEm6StIeSdsyy46RtDGNRxslLWzy2ovTbR6TdHHbhUfE0NyAVwAvAzYDtczylcBDwOHACcDjwIKc118HrE0fXwm8rwd1/gTwoSbrngDGenj8rgD+cJ5tFqTH70TgsPS4riy5XucAh6aPPwZ8rF/Hq5X3D/wecGX6eC3wpR58dkuA16SPXwj8U0691gA39er71M5nA5wHfAMQ8Abg3h7XbwHwY5JBQz0/ZsDpwGuAbZllHwcuSx9flve9B44Bvp/eL0wfL2yn7KFqyUfEjoh4NGfV+cAXI+K5iPgBsBNYld1AkoAzgevTRdcAF5RZ37TMdwBfKLOcgq0CdkbE9yPi/wBfJDm+pYmIWyNiJn16D7C0zPLm0cr7P5/k+wPJ9+ms9LMuTUTsjogH0sfPAjuA48sss2DnA38biXuAoyUt6WH5ZwGPR0Q3I+o7FhF3AvsaFme/R83i0ZuBjRGxLyJ+AmwEzm2n7KEK8nM4Hngy83ySX/0DeBHw00wwydumaKuBqYh4rMn6AG6VtEXSupLrUndp+u/yVU3+PWzlWJbpPSQtvjy9OF6tvP//v036ffoZyferJ9L00KuBe3NWnyLpIUnfkPTKXtWJ+T+bfn+v1tK8sdWvY7Y4InZD8iMO5F18uuvjNnCX/5N0G3BszqrLI+LGZi/LWdbYN7SVbVrWYj0vYu5W/GkRsUvSImCjpO+lv/gdm6tewKeBj5C874+QpJLe07iLnNd23c+2leMl6XJgBvh8k90UfrzyqpqzrNTvUjskHQl8BfhARDzTsPoBknTEP6fnW74GnNSLejH/Z9PPY3YY8NvAB3NW9/OYtaLr4zZwQT4izu7gZZPAsszzpcCuhm32kvyLeGja+srbpmXz1VPSocDbgNfOsY9d6f0eSTeQpAq6ClqtHj9JnwFuylnVyrEsvF7pCaW3AmdFmozM2UfhxytHK++/vs1k+jkfxa/+K144Sc8jCfCfj4ivNq7PBv2IuFnSX0sai4jSJ+Jq4bMp5XvVorcAD0TEVOOKfh4zYErSkojYnaau9uRsM0ly3qBuKck5yZZVJV2zAVib9no4geSX+L7sBmng2AT8TrroYqDZfwZFOBv4XkRM5q2U9AJJL6w/Jjn5uC1v26I05EAvbFLe/cBJSnoiHUbyb+6Gkut1LvBfgd+OiF802aZXx6uV97+B5PsDyffpjmY/TEVJc/6fBXZExF802ebY+rkBSatI/r6fLrNeaVmtfDYbgHelvWzeAPysnqrogab/UffrmKWy36Nm8egW4BxJC9P06jnpstaVfVa5yBtJYJoEngOmgFsy6y4n6RXxKPCWzPKbgePSxyeSBP+dwJeBw0us69XAJQ3LjgNuztTlofS2nSRtUfbx+zvgYWBr+gVb0liv9Pl5JL03Hu9RvXaS5B0fTG9XNtarl8cr7/0DHyb5EQJ4fvr92Zl+n07swTF6I8m/6Vszx+k84JL69wy4ND02D5GcwD617HrN9dk01E3Ap9Jj+jCZ3nEl1+3XSIL2UZllPT9mJD8yu4H/m8aw95Kcx7kdeCy9Pybdtgb8Tea170m/azuBd7dbtqc1MDOrsKqka8zMLIeDvJlZhTnIm5lVmIO8mVmFOcibmVWYg7yZWYU5yJuZVdj/A9pJHOWc/ZqnAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "x = np.arange(-10, 10, 0.1)\n", + "y = x * (net.weight.default_input[0][0].asnumpy()) + (net.bias.default_input[0].asnumpy())\n", + "plt.scatter(x1, y1, color=\"red\", s=5)\n", + "plt.plot(x, y, \"blue\")\n", + "plt.title(\"Eval data and net\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "红色的点:为之前生成的50组验证数据集。\n", + "\n", + "蓝色的线:初始化的模型网络。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 定义损失函数" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "我们的网络模型表达式为:\n", + "\n", + "$$h(x)=wx+b\\tag{2}$$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "一般地,数学上对线性回归模型采用均方差的方式来判断模型是否拟合得很好,即均方差的值$J(w)$值越小,函数模型便拟合得越好,验证数据代入后,预测得到的y值就越准确。公式2对应m个数据的均方差公式为:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "$$J(w)=\\frac{1}{m}\\sum_{i=1}^m(h(x_i)-y^{(i)})^2\\tag{3}$$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "为了方便后续的计算,我们采用0.5倍的均方差的表达式来进行计算,均方差值整体缩小至0.5倍的计算方式对判断模型拟合的好坏没有影响。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "$$J(w)=\\frac{1}{2m}\\sum_{i=1}^m(h(x_i)-y^{(i)})^2\\tag{4}$$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "公式4即为网络训练中的损失函数,其中参数:\n", + "\n", + "- $J(w)$为均方差。\n", + "\n", + "- $m$为样本数据的数量。\n", + "\n", + "- $h(x_i)$为第$i$个数据的$x_i$值代入模型网络(公式2)后的预测值。\n", + "\n", + "- $y^{(i)}$为第$i$个数据中的$y$值(label值)。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "在MindSpore中定义损失函数的方法如下。" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from mindspore.ops import operations as P\n", + "\n", + "class MyLoss(nn.loss.loss._Loss):\n", + " def __init__(self,reduction='mean'):\n", + " super().__init__(reduction)\n", + " self.square = P.Square()\n", + " def construct(self, data, label):\n", + " x = self.square(data-label) * 0.5\n", + " return self.get_loss(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "其中:\n", + "\n", + "- `nn.loss.loss._Loss`:是MindSpore自定义loss算子的一个基类。\n", + "\n", + "- `P.Square`:MindSpore训练的框架中的平方算子,算子需要注册过才能在框架的计算图中使用。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 损失函数与网络结合\n", + "\n", + "接下来我们需要将loss函数的表达式和网络net关联在一起,在MindSpore中需要`nn.WithLossCell`,实现方法如下:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "criterion = MyLoss()\n", + "loss_opeartion = nn.WithLossCell(net, criterion) " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "其中:\n", + "\n", + "- `net`:网络模型。\n", + "\n", + "- `criterion`:即为实例化的loss函数。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "上述从数据代入到计算出loss值的过程为AI训练中的前向传播过程。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 定义反向传播网络" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "有了损失函数后,我们如何使得损失函数最小呢?我们可以将公式1代入到损失函数公式4中展开:\n", + "\n", + "$$J(w,b)=\\frac{1}{2m}\\sum_{i=1}^m(wx_i+b-y^{(i)})^2\\tag{5}$$" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "公式5可以将$J(w)$看作为凹函数,对权重值$w$微分可求得:\n", + "\n", + "$$\\frac{\\partial{J(w)}}{\\partial{w}}=\\frac{1}{m}\\sum_{i=1}^mx_i(wx_i+b-y^{(i)})\\tag{6}$$\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "由凹函数的特性可以知道,当公式6等于0时,损失函数有最小值:\n", + "\n", + "$$\\sum_{i=1}^mx_i(wx_i+b-y^{(i)})=0\\tag{7}$$ \n", + "\n", + "假设有一个$w_{min}$使得公式7成立。我们如何将初始的权重$w_{s}$逐步的变成$w_{min}$,在这里采取迭代法,也就是梯度下降方法\n", + "\n", + "当权重$w_{s}w_{min}$,权重值需要左移即权重值变小接近$w_{min}$,才能使得损失函数逐步的变小,由凹函数的性质可知,在$w_{s}$处的导数为正(损失函数在$w_{min}$右边单调上升),公式8的值为正。其权重的更新公式为:\n", + "\n", + "$$w_{ud}=w_{s}-\\alpha\\frac{\\partial{J(w_{s})}}{\\partial{w}}\\tag{10}$$\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "当$w_{s}=w_{min}$时,到$\\frac{\\partial{J(w_{s})}}{\\partial{w}}$=0,即梯度消失,其表达式也可写为公式9的样式。\n", + "\n", + "在考虑了全区间的情况后,可以得出权重$w$的更新公式即为:\n", + "\n", + "$$w_{ud}=w_{s}-\\alpha\\frac{\\partial{J(w_{s})}}{\\partial{w}}\\tag{11}$$\n", + "\n", + "当权重$w$在更新的过程中假如临近$w_{min}$在增加或者减少一个$\\Delta{w}$,从左边或者右边越过了$w_{min}$,公式10都会使权重往反的方向移动,那么最终$w_{s}$的值会在$w_{min}$附近来回迭代,在实际训练中我们也是这样采用迭代的方式取得最优权重$w$,使得损失函数无限逼近局部最小值。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "同理:对于公式5中的另一个权重$b$容易得出其更新公式为:\n", + "\n", + "$$b_{ud}=b_{s}-\\alpha\\frac{\\partial{J(b_{s})}}{\\partial{b}}\\tag{12}$$\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "当所有的权重更新完成后,将新的权重赋值给初始权重:即$w_{s}$=$w_{ud}$,$b_{s}$=$b_{ud}$。将新的初始权重传递回到模型函数中,这样就完成了反向传播的过程。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> 当遇到多项式的回归模型时,上述梯度方法也适用,由于权重数量的增加,需要将权重的名称更新为$w_0,w_1,w_2,...,w_n$,引入矩阵的表达方式,公式将会更加简洁,这里就不多介绍了。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 实现梯度函数" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "在MindSpore中的所有要编入计算图的类都需要继承`nn.Cell`算子。MindSpore的梯度计算函数采用如下方式。" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from mindspore.ops import composite as C\n", + "\n", + "class GradWrap(nn.Cell):\n", + " \"\"\" GradWrap definition \"\"\"\n", + " def __init__(self, network):\n", + " super().__init__(auto_prefix=False)\n", + " self.network = network\n", + " self.weights = ms.ParameterTuple(filter(lambda x: x.requires_grad,\n", + " network.get_parameters()))\n", + "\n", + " def construct(self, data, label):\n", + " weights = self.weights\n", + " return C.GradOperation('get_by_list', get_by_list=True) \\\n", + " (self.network, weights)(data, label)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "上述代码中`GradWrap`实现的是对各个权重的微分$\\frac{\\partial{J(w)}}{\\partial{w}}$,其展开式子参考公式6。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 反向传播更新权重" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`nn.RMSProp`为完成权重更新的函数,更新方式大致为公式10,但是考虑的因素更多,具体信息请参考官网说明:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "train_network = GradWrap(loss_opeartion) \n", + "train_network.set_train()\n", + "optim = nn.RMSProp(params=net.trainable_params(),learning_rate=0.02)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "通过以上操作,我们就完成了前向传播网络和反向传播网络的定义,接下来可以加载训练数据进行线性拟合了。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 定义模型拟合过程可视化函数" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "定义一个可视化函数`plot_model_and_datasets`,将模型函数和验证数据集打印出来,观察其变化。" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import time \n", + "\n", + "def plot_model_and_datasets(weight, bias, data_x, data_y):\n", + " x = np.arange(-10, 10, 0.1)\n", + " y = x * ((weight[0][0]).asnumpy()) + ((bias[0]).asnumpy())\n", + " plt.scatter(x1,y1,color=\"red\",s=5)\n", + " plt.scatter(data_x.asnumpy(), data_y.asnumpy(), color=\"black\", s=5)\n", + " plt.plot(x, y, \"blue\")\n", + " plt.show()\n", + " time.sleep(0.02)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "上述函数的参数:\n", + "\n", + "- `weight`:模型函数的权重,即$w$。\n", + "\n", + "- `bias`:模型函数的权重,既$b$。\n", + "\n", + "- `data_x`:训练数据的x值。\n", + "\n", + "- `data_y`:训练数据的y值。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> 可视化过程中,红色的点是验证数据集,黑色的点是单个batch的训练数据,蓝色的线条是正在训练的回归模型。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 执行训练" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "其训练过程如下:\n", + "\n", + "1. 设置训练的迭代次数`step_size`。\n", + "2. 设置单次迭代的训练数据量`batch_size`。\n", + "3. 正向传播训练`grads`。\n", + "4. 反向传播训练`optim`。\n", + "5. 图形展示模型函数和数据集。\n", + "6. 清除本轮迭代的输出`display.clear_output`,起到动态可视化效果。\n", + "\n", + "迭代完成后,输出网络模型的权重值$w和b$。" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss_value: 0.42879593\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD6CAYAAABEUDf/AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3deXxU1f3/8dcniWKLVEEzgbKIWytRkSWgBNCgKJafBbW14lJtQaCodf1ZQani169rxaUqUlBcqRY3QAURMKgQQAOyxwWXliCGYRFUVAz3fP+4N3aIk5BkZjKTyfv5eOSRyb135n5yM/nk5NxzPsecc4iISHrKSHYAIiKSOEryIiJpTEleRCSNKcmLiKQxJXkRkTSmJC8iksZiTvJm1tbMCs2sxMxWm9nlwfYxZrbezJYFH/1jD1dERGrDYh0nb2atgFbOuaVm1gxYApwO/A74yjl3V01f68ADD3Tt27ePKR4RkcZmyZIlm5xz2dH2ZcX64s65DcCG4PGXZlYCtK7La7Vv357i4uJYQxIRaVTM7N9V7Ytrn7yZtQc6A4uDTZea2Qozm2Rmzat4zjAzKzaz4nA4HM9wREQavbgleTPbF3geuMI5tx14CDgU6ITf0h8b7XnOuQnOuTznXF52dtT/NkREpI7ikuTNbC/8BD/ZOfcCgHOuzDm3yznnAROB7vE4l4iI1Fw8RtcY8AhQ4py7O2J7q4jDzgBWxXouERGpnZhvvAI9gd8DK81sWbDtOuAcM+sEOOBTYHgcziUiIrUQj9E18wGLsmtGrK8tIiKx0YxXEZE0piQvIpJEnudRVlZGohZwUpIXEUkSz/Po06cPbdq0oaCgAM/z4n4OJXkRkSQJh8MUFRVRXl5OUVERiZgQqiQvIpIkoVCIHj16kpl5CPn5+YRCobifIx5DKEVEpA4++MAwKyQnx+OllzLwpx3Fl1ryIiL1bOdOuOUWOOYYWLHCuPnmTJo1i3+CByV5EZG4qOkomcWLIS8PRo+GAQOgpAQGD4YENOIBJXkRkZjVZJTMV1/BFVdAjx6wZQtMnQpTpkDLlomNTUleRCRGexolM3MmHHkk3HcfjBgBa9bAwIH1E5uSvIhIjEKhEPn5+WRlZe02SiYchvPPh/79oWlTmD8fHnwQfvaz+otNo2tERGJkZhQWFhIuKyNkBg6efAquvBK2b4cbb4RRo6BJk/qPTS15EZEYVNxwNefIGTSIT9v0pt8BxVxwARx+OLz7LowZk5wED0ryIiJ1FnnD9YReBYx9qxtH7VrGwi+O4P5btjN/vt8Xn0zqrhERqaP/3nDN5a1F9/AWeZxmLzOu22O0HfVs9CLs9UwteRGROmrWLESrVo8CS9hrr0N4+p8e0z/rRttFzyZu4HstxWP5v7ZmVmhmJWa22swuD7a3MLPZZvZh8Ll57OGKiKSGwkI45hhj3brzGTToezZsaM6gczKwljkpk+AhPi35cuBq51wH4DjgEjPLBUYCc51zhwNzg69FRBq0rVth6FA48UTwPJg9G55++icccEDqJPZIMSd559wG59zS4PGXQAnQGhgIPB4c9jhweqznEhFJFufgueegQwd49FG45hpYuRL69k12ZNWL641XM2sPdAYWAznOuQ3g/yEws6g1NM1sGDAMoF27dvEMR0QkLtavh0sugWnToHNnmDEDunRJdlQ1E7cbr2a2L/A8cIVzbntNn+ecm+Ccy3PO5WVnZ8crHBGRmHkePPQQ5ObCrFlw553w9tsNJ8FDnFryZrYXfoKf7Jx7IdhcZmatglZ8K2BjPM4lIlIfSkr8vvcFC+Ckk+Af/4BDD012VLUXj9E1BjwClDjn7o7YNR24MHh8ITAt1nOJiCTazp1w883QqZNfSOzRR/2bqw0xwUN8WvI9gd8DK81sWbDtOuB2YIqZDQH+A5wVh3OJiCTMwoV+6331ahg0CO69F3Jykh1VbGJO8s65+VQ9r+ukWF9fRCSRPM/jk082ce+92Tz4oNG6Nbz0Epx2WrIjiw/NeBWRtOWVl1O2ahUuyiIe4Cf4jh1Hcdhh3/LAA46LL3asWZM+CR6U5EUkTXnl5fQ54ADaHH00Bc2b45WX77Z/40Y488xvWb36DuBLMunFX6//nGbNkhNvoijJi0haCq9ZQ9H27ZQDRdu3E16zBvAnNT3+OHTo4Jg+PQu4AehCPgv9WvBpRkleRNJSKDubfPwbj/nB1x9/DKecAn/4Axx22PdkZHQFbiaLnUw59lisod9ljUJJXkTSkrVsSWHv3pRmZjKnVwFjn2rJUUfB4sUwbhwUFe1Fz54t/CX7evQgp6gopQqLxYvqyYtIejIjY948Pnt9K/2vbcHSvxi//rWf4Nu0AQiW7AuHCYVCWBomeFBLXkTS1I4dcO2oDLqdegDr1xtTpvi1Z/wE78vIyCAnJydtEzyoJS8iaej112HYMPjoIxgyBP52h0fz8jAQIiWWa6pHasmLSNrYsgUGD/ZrzZj5yf7hCR7Nz+zjN+ELCvyqY42IkryINHjOwZQpfq33J56AkSNhxQro0wcIh6GoCMrL/c/hcLLDrVdK8iLSoK1bBwMHwtlnQ9u2UFwMt90GP/lJcEAoBPn5kJXlfw5FXdoibalPXkQaHM/zKCsL8/xz2YwaBbs8Y+xY47LL/Fy+GzN/QdZw2E/waXyTNRq15EWk4fA8vA0b6N79D/z85x/x58syOO7ruaw66hyuusL7cYKvkJHhl5NsZAke1JIXkVTgeXtuaXse351wCtcvOJ4l7mHgSzK4gCd5kpbvZvnPT8MZq7FSS15Eksvz/Dukexj9UjTjCzrP/ztj3Q1k8yyZmUfTa7/p5GRmNsq+9pqKS5I3s0lmttHMVkVsG2Nm681sWfDRPx7nEpE0s4fRL9u3w6WXOHoNaM7XTVrwSsav+bz3P1hfupR5mzdj69fDvHmNsiumJuLVkn8MODXK9nucc52CjxlxOpeIpJNqRr+89BLk5jrGjXNcxv2szvsD/UsnkPHGG+S0bIllZjbavvaaikufvHPuTTNrH4/XEpFGJsrol88/h8sv98e+H3VEOc9nFnDsriJYnOXfRFVSr7FE98lfamYrgu6c5tEOMLNhZlZsZsXhRjZJQUQCwegXhzFpkj+paepUf0HtJcuyOLZnVqMd5x6rRCb5h4BDgU7ABmBstIOccxOcc3nOubzs7OwEhiMiqWztWujb1681c/TRsHw5jB4NezcJWvqlpep7r4OEJXnnXJlzbpdzzgMmAt0TdS4RabjKy+HOO/3EXlwM48f7ufyIIyIOasTj3GOVsHHyZtbKObch+PIMYFV1x4tI47NkCVx0ESxbBqefDg88AK1bJzuq9BKXJG9mTwMFwIFmVgrcCBSYWSfAAZ8Cw+NxLhFp+HbsgBtvhLvv9rvYn38ezjwz2VGlp3iNrjknyuZH4vHaIpJeZs+G4cPhk09g6FC/q2b//ZMdVfrSjFcRqRebN/sLaJ9yij9QZt48mDBBCT7RlORFJKGcg2ee8YdFTp4M113n13o/4YRkR9Y4qECZiCTMf/4DI0bAjBnQrRvMmQMdOyY7qsZFLXkRibtdu+D++yE31++WueceWLhQCT4Z1JIXkbhatcofFrl4MfTrB+PHebRvGoaMxreIdipQS15EfuCvuFSGc67Wz/3uO7jhBujSxZ+9+tRTMPMVj/Z/bLyLaKcCJXkRAfwE36dPH9q0aUNBQQFeLRLy/PnQqZNfa+bss6GkBM47D2xT415EOxUoyYsIAOFwmKKiIsrLyykqKqImBQO3bfNvrPbuDd98A6++Ck8+CT+UoWrki2inAiV5EQEgFAqRn59PVlYW+fn5hPaQkKdN82+sTpgAV17p98X361fpIFNxsWTTjVcRAcDMKCwsJBwOEwqFsCoS8oYNcNll8Nxz/miZqVP94ZFVqiguJkmhlryI/CAjI4OcnJyoCd45ePhhf1LTSy/Brbf6VSOrTfCSdGrJi8geffABDBsGb7zhz1SdMAF+8YtkRyU1oZa8iFTp++/httv8bplly/zk/vrrSvANiVryIo2V5+22rmpl77zjT2pasQJ+8xt/BmurVkmIU2KilrxIY+R50Cf6JKWvv4arr4bjjoNNm+DFF/2brErwDVNcknywUPdGM1sVsa2Fmc02sw+Dz1EX8haRJAhHn6Q0axYcdZS/mMewYbBmjb9ikzRc8WrJPwacWmnbSGCuc+5wYG7wtYikgkqTlDZlhLjgAjj1VGjSBN58Ex56CPbbL9mBSqzikuSdc28CWyptHgg8Hjx+HFB7QCRZPA/KyvzykGVl/rbCQty6UiYPnUeHXOPpp2H0aP8Ga+/eyQ1X4ieRffI5FQt5B581n1kkGSr631u3hgMO+KEf/t//hv5/zOH83xuHHgpLl/q1Z/bZJ9kBSzwlfXSNmQ0DhgG0a9cuydGIpKGK/vddu2DbNnaRwf3z8xh9FGBw331wySWQmZnsQCUREtmSLzOzVgDB543RDnLOTXDO5Tnn8rJ/qGokInFT0f+emcmKffPpwSKu9MZy/AnG6tV+iYKoCb6ii6cOZYcldSQyyU8HLgweXwhMS+C5RKQqZnw7s5DRf/6Crt/O59MDuvLPyY5XXjEOOqiK51QzxFIalngNoXwaWAj80sxKzWwIcDtwspl9CJwcfC0i9SVoib8xz3FM5wxuuXdfzjvPKHk/g3POteoLQlYxxFIanrj0yTvnzqli10nxeH0RqSXP44vev+bahaczwQ3l4IMdr71mnHxyDZ9f0cVTVKQ68A1c0m+8ikgdVVOW4IXHtnNp0UTKyOFqu5ub5pxH00NqUe63og58NWUPpGFQWQORhqiKPvPPSj3O/H/f8psh+xNq+jWLM3tyV+9pND24Di3xijrwSvANmpK8SENUqc/cKwszYbxHh/Y7mDnDcfvB43knfDB566dpRaZGTklepCGKKEvwfqez6TMoxPARGXT13mElR3Ptuj+z1/bNaomLkrxIg2TGzlmF3HLVJo5Z8QQrVsAjDzvm9hrDYVn/rv3NUo2JT1tK8iIN0OLF0DXPGH3nfgzY+Rwlu37J4At3YfPqsGi2xsSnNSV5kQbkq6/giiugRw/YGi5nGgOYwtm0/PJDeO+9ut0s1Zj4tKYkL9JAzJwJRx7puO8+GPEnx5r3Mxmw35v+zv32g9zcur1wpbLDGhOfXpTkRVJcOAznnQf9+0PTTf9mfuYJPLi6gJ/9DH/pppUrYcsWvxVfFxVj4mvbzSMNgpK8SIpyDp58Ejp0gGefhRuv/op3vzuSnrve/G+3SlaWv5RTXRN8BY2JT1tK8iIp6JNPoF8/uOACOPxwePddGPO3pjTpmaduFakVlTUQSSHl5fD3v8Nf/+o3ru+/H0aMqCgFrFIDUntK8iKpwPNYPm8rF13bguJi47TTYNw4aNu20nEV3SoiNaTuGpEk++Zrj1Ht/0nXk/bjP8u38sw/PaZPj5LgRepASV4kiQoLoePRHrevO58LeIIS7wjOPjGsnhiJGyV5kSTYuhWGDoUTTwTPMpl99FVMyhpOi54ddENV4irhffJm9inwJbALKHfO5SX6nCKpyu3yeP7R7Vw6ej82bTL+8he48Ubjp/vcBeFrdUNV4q6+brz2cc5tqqdziSSM53mEw2FCoRBmVu3CHZWtX+dxSecFTNvcm877fsCMRYfRJa/in2ndUJXEUHeNSA15nkefPn1o06YNBQUFeOXlNSrs5Xnw0EPQIRdmbc7jTq7h7W860qWtasRI4tVHknfAa2a2xMyGVd5pZsPMrNjMisMqjCQpLBwOU1RURHl5OUVFRYTfe2+Phb1KSuD44+Hii6H7scaqboO5Jutesnoeq753qRf1keR7Oue6AL8CLjGz4yN3OucmOOfynHN52dnZ9RCOSN2EQiHy8/PJysoiPz+fUG5ulYW9du6Em2+GTp1gzRp49FGYPds4dNFk1YiRepXwPnnn3GfB541m9iLQHXgz0ecViTczo7CwcPc++SgzUBcu9EfOrF4NgwbBvfdGdLeb+t6lfiW0JW9mTc2sWcVj4BRgVSLPKZJIGRkZ5OTk+Ane3/BDYa8vv4TLLoOePWHbNnjpJXj6aeV0Sa5Et+RzgBeDX4gs4J/OuVcTfE6RevfKK36NmdJSuOQSuPVWaNYs2VGJJDjJO+c+Bo5J5DlEkmnjRrj8cnjmGX/NjgUL/FWbRFKFhlCK1IFz8Pjjfq33F16Am27yywErwUuqURVKkVr66CP4059gzhy//33iRD/Zi6QiteRFaqi8HO66C44+GhYv9ksBv/mmErykNrXkRWrg3Xfhootg6VIYMAAefNCf6CqS6tSSF6nGjh1w7bXQrRusX++vtTp1qhK8NBxqyYtUYe5cGD7c74MfMgT+9jdo3jzZUYnUjlryIpVs2QKDB0Pfvv4k1tdfh4cfVoKXhklJXiTgHEyZ4t9IfeIJGDkSVqzwC02KNFTqrhEB1q3zZ6q+9BJ07QqzZvnFxUQaOrXkpXHwPCgr85vrlTY/+KA/W3XOHBg7FhYtUoKX9KEkL+nN82DDBn9Rj0qLe6xeDb16waWX+jNVV62Cq67yKweLpAu9nSV9eZ7foR4s7OEB4QUL2G/dRm5/tOUPRcQefxx+/3uVd5f0pCQv6Ssc3i3B9wEW7DqWJrnfsmMHnHsu3HOPFmiS9KbuGklfodAPKzd9nHcib9mD7OItduwwJk/eyuTJSvCS/tSSl/QVrNw0/antXHzdfjjnMHuA/PwZnHPOzGRHJ1IvEt6SN7NTzex9M1trZiMTfT6RCp9/Dr8blMHAC/eneXOjqAg2bDibt96a+d+VnUTSXKKX/8sEHsRfxDsXOMfMchN5ThHnYNIkf1LTtGn+gtpLlkCPHpWW7hNpBBLdku8OrHXOfeyc2wk8AwxM8DmlEVu71i9HMGSIXxJ4+XIYPRr23jvZkYkkR6KTfGtgXcTXpcE2kbj6/nu44w4/sRcXw/jxMG8eHHFEsiMTSa5E33iN9n/xblMOzWwYMAygXbt2CQ5HUp7n+UMfQ6EaD1xfssSv9b5sGZx+OjzwALRWU0IESHxLvhRoG/F1G+CzyAOccxOcc3nOubzs7OwEhyMprWLyUqWZqVXZsQOuuQa6d/dvsj7/PLz4ohK8SKREJ/l3gMPN7GAz2xsYBExP8DmloYqYvERRkf91FWbPhqOO8pfjGzIESkrgzDPrMVaRBiKhSd45Vw5cCswCSoApzrnViTynNGARk5fIz4cDD/xRUbHNm+EPf4BTTvEPmzcPJkyA/fdPWtQiKS3hk6GcczOAGYk+j6SBYPIS4bCf4E880W/R5+fjXi/kmSkZXH45bN0K110Hf/0r7LNPsoMWSW2a8SqpJSMDcnL8FnzQdfOfBesY0e97ZsxtQrdufkngjh2THahIw6DaNZKaQiF29ejF/RmXk+tWMW/h3txzDyxcqAQvUhtqyUtKWrXauGjn6yz2jH79HOPHG+3bJzsqkYZHLXlJKd9+CzfcAF26wNq1xlNPwcyZSvAidaWWvKSMt96CoUPh/ffh/PPh7rtBUydEYqOWvCTdtm0wYgQcf7zfkn/1VXjySSV4kXhQkpf4qWKx7OpMm+Yvoj1hAlx5pb/Oar9+CYxRpJFRkpf4qChJ0Lq1P5Fp165qD9+wAX77W7/WzIEHwqJFfvfMvvvWU7wijYSSvMRHOAwLFvjJfdEi6N07au0Z5+Dhh/1a7y+/DLfe6leN7NYtCTGLNAJK8hIfodDumfqdd35Ue+aDD/zG/tCh0KkTrFgBo0bBXnvVc6wijYiSvMSHGcyfDz16/Lf2TLBK9vffw223+ZOYli3z+99ffx1+8YskxyzSCGgIpcRPZqaf6CPqwb/zjl/rfcUK+M1v4P77oVWrKM+tQx15EdkzteQlvoLaM1/vMK66Co47DjZt8uu8P/dcNQm+FnXkRaTmlOSlalGGRHqeR1lZGa7SMMnI7bNm+bXe77kHhg11rJm3kdMHVjOsshZ15EWkdpTk5b8ik3qU1rXnefTp04c2bdpQUFCAF7S4K7a3bn0MLVu+xqmnQpMm8OY8j4dKCtgvt3X1LfTKdeSDvnwRiZ1VbpElU15enisuLk52GI1TRVIP6rfzzDPQrp3fus7KgtJSyoA2bdpQXl5OVlYWpaWl5OTk8PnnZbRufQ2eNxbYnyuv/I5bb92XfbaV+X8kIl6DnJyqz68+eZE6MbMlzrm8aPsS1pI3szFmtt7MlgUf/RN1LomDoMvEKy+nbMECf7X1Sq3rUChEfn4+WVlZ5OfnEwqF+PRT+OMfQ3jeE8DHdO06jLFjm/qLedSmhV5RR14JXiSuEtaSN7MxwFfOubtq+hy15JPIObwTTqDPW29RZEZ+794Uzp1LxubNu7WuPc8jHA5zwAEhHnjAuP56f9ctt3icdVaYVq1CWGSiVgtdJOGqa8lrCKX4zAj/618UtWtHeXk5RUVFhDdvJqdS90pGRgZlZTn8+tf+fKf+/WHcODjooAwgSldMRQtdRJIi0TdeLzWzFWY2ycyaRzvAzIaZWbGZFYc1qiKpQi1b/qg7JtK338Lo0dC1K3z6KfzzKY+XHynjoHapc19HRHYXU3eNmc0BWkbZdT2wCNgEOOBmoJVzbnB1r6fumuSr6I4JhXbvdnnjDRg2zC9NcOGFMPZvHgf8NuJGbWGh32oXkXqXsO4a51zfGgYwEXg5lnNJ/cjIyNiti+aLL+Avf4GJE+Hgg+G11+Dkk4GyKGPb1S0jknISObomcm7jGcCqRJ1LYhStDrzn8cKkL8jNdTzyCFx9NaxcGSR40Nh2kQYikTde7zSzTvjdNZ8CwxN4LqmryuPjCwv57DO4tNN8Xtx8PJ2afsj0hYeS171Se8DM76LRyBmRlJawlrxz7vfOuaOdcx2dcwOccxsSdS6JQURJAW/BQibc/RUdcmHm5m7czrW8/W1H8g6q4oa4xraLpDzdKWvsgm6X9zNzKWj6DsOv+Rld84yV3YZwbdbd7NWzu7piRBowJfl0VoM1V3d+b9xyciEdM1exMqMjjzwCc+cahy16yi9DMG+eWuoiDZiSfLqqosBYZAXJxYv9Me+j/5rBwIFGSYkxeHCQ09UVI5IWlOTTVaXyvV5Z2Q8VJHv3/hWXD/2aHj0cW7fCtGkwZQq0jDbjQUQaNCX5dFVpiGPYjKKiIsrL+7JgwT/4+8NNGdFyKmtWeQwYkOxgRSRRVLsmXVUa4mhhaN58BuHwyfyUNcyiJ73Cb8N3pUStOSMiaUEt+XSWkYEL5fDkU0ZurvHFF325+uqv2Nzzz/TKeluTmEQaAbXk09gnn8Dw4TB7NvToARMnGkceuS94szWJSaSRUEs+DZWXw913++usLlwIDzwA8+fDkUcGB2jkjEijoZZ8mlm2DC66CJYsgdNO82u9t22b7KhEJFnUkk9Blcez18Q338CoUZCXB+vW+Uu0Tp+uBC/S2CnJpxjP834Yz15QUIDneXt8TmEhdOwIt98OF1wAJSVw9tnqjRERJfmUEw6Hg/HswRJ81ayWtXUrDB0KJ57oT3CdMwcmTYIWLeoxYBFJaUryqSKoMxPKzq52CT7wS9E89xx06ACPPuov6rFyJZx0UhLiFpGUphuvqSCiprvl51M4dy7hzZt/tAQf+DXDLrnE72/v3BlmzIAuXZIUt4ikvJha8mZ2lpmtNjPPzPIq7RtlZmvN7H0z6xdbmGmuUp2ZjM2bycnJ2S3Bex489BDk5vrj3u+8E95+WwleRKoXa3fNKuBM4M3IjWaWCwwCjgROBcaZWWaM50pfe1hKr6QEjj8eLr4Yunf3u2auucY/XESkOrEu5F0C/KhLARgIPOOc+w74xMzWAt2BhbGcL21VsZTezp3+iJlbboGmTf3+9wsv1KgZEam5RLUFWwOLIr4uDbb9iJkNA4YBtGvXLkHhNAAVs1ADCxf6I2dWr4ZBg+Dee3fbLSJSI3vsrjGzOWa2KsrHwOqeFmVb1Jk9zrkJzrk851xednZ2TeNOW19+CZddBj17wrZt8NJL8PTTSvAiUjd7bMk75/rW4XVLgci5lm2Az+rwOo3KK6/AiBH/HUFz663QrFmyoxKRhixR4+SnA4PMrImZHQwcDrydoHM1eBs3wjnn+LVmmjWDBQvg/vs8mu2ofn1WEZE9iXUI5RlmVgr0AF4xs1kAzrnVwBRgDfAqcIlzbleswaYb5+Cxx/xJTS+8ADfdBO++Cz2O/fH6rCIidWG1KYKVaHl5ea64uDjZYdSLjz7ya73Pnev3v0+c6Cd7AMrK/ARfXu6PkywtVae8iFTJzJY45/Ki7VNZg3pWXg533QVHH+1PZho3Dt58MyLBwx7HzYuI1JSm0ySY53mEw2FCoRDLlhkXXQRLl8KAAfDgg36D/UeqGDcvIlJbasknUEXZ4NatD6ddu6fp1s2xfj08+yxMnVpFgq+g1ZtEJA7Ukk+gcFkZ8+c3wfPepbT0UM499xse+HsTmpeHgRDRpxOIiMSPWvIJsmWTx8jOS/G81wCPjh2v4Kkn9qb5mRo1IyL1R0k+zpyDKVOgQwfHk2X9uJbb+CSjM8tmjcQ2bdqt2iTVLAgiIhIPSvJxtG6df0P17LOh7UEZFHcZzu1ZN9C+V1csJ0ejZkSk3qlPPg4qar2PHAm7dsHYsXDZZUZWxkQI37r7CBmNmhGReqQkH6PVq/1qkQsXwsknw/jxcMghFXszfjyJKSPKNhGRBFF3TR199x2MGQOdOzvef8/j8fu+YNarLiLBi4gkn5J8HRQV+eur3nQTnLX/HEq2tuSCK1pgfQo0YkZEUoqSfC1s3+6XAO7VC77+GmZM3srkLb8iRNgfVlNUBGvWqHKkiKQMJfkamj7dX0T7oYf8RT1Wr4ZfnbO/P0oG/JuoP/0pdOqkMfAikjKU5Pfg88/hd7+DgQOheXP/Buu998K+++In9nnz4LPPYPlyv3m/a5fGwItIylCSr4JzMGmSXx1y2jT43/+FJUvg2GMrHZiRAa1awVFH+TWDNQZeRFJIrIuGnGVmq83MM7O8iO3tzewbM1sWfIyPPdT6s3Yt9O0LQ4b4JYGXL4frr4e9967mSRWVI0tL/da9xsCLSAqIdZz8KuBM4JXqPaIAAAjPSURBVB9R9n3knOsU4+vXD8+DcJjvm4e4+x5jzBg/oY8f74+Bz6jpn0KNgReRFBNTknfOlQBYQ2u1BkmdUMjvl+nThyULvuWifZ5i2deHc8YZcP/90Lp1sgMVEYlNIvvkDzazd83sDTPrncDz1I63+/qpOz4p45q3BtB9VxFlXzfl+Ue+4IUXlOBFJD3ssSVvZnOAllF2Xe+cm1bF0zYA7Zxzm82sKzDVzI50zm2P8vrDgGEA7dq1q3nkNRXZajfzHweVIGfP/wnDT8zmE3c1w2wid/SYyv5/fDn+MYiIJMkeW/LOub7OuaOifFSV4HHOfeec2xw8XgJ8BPyiimMnOOfynHN52dnZdf0+oqvUasfzIBRic7dTudCe4BTvVbKaZDLvdY9/bBjA/vNf1g1TEUkrCSlQZmbZwBbn3C4zOwQ4HPg4EeeqVkSrnaIi3MYwzxTmcPna6WzNhOuucfz1BmOffQzQDVMRST+xDqE8w8xKgR7AK2Y2K9h1PLDCzJYDzwF/cs5tiS3UOoio3/6fLqdz2pAQ554L7dsbS5YYt9xq7LNPvUclIlJvzKVQnZW8vDxXXFwc19fc9b3HuDu/YtRtzXDOuOUW+POfITMzrqcREUkaM1vinMuLti+t68mvWgUXXZTB4sU/o98pjvH/G6Z93oHqdxeRRiMtyxp8+y3ccINfDnjtWnjqCY+Z3xTQPv/nKh4mIo1K2iX5t97yC0HefDMMGgQlJXDeKWFsoRbQFpHGJ22S/LZtMGIEHH+835J/9VV48knIzkYLaItIo5UWffLFxX4p4M8/hyuvhP/5n6AUcIWK4mFaQFtEGpm0SPKHHAJHHglTp0K3blUcpOJhItIIpUWSb9ECXnst2VGIiKSetOmTFxGRH1OSFxFJY0ryIiJpTEleRCSNKcmLiKQxJXkRkTSmJC8iksaU5EVE0lhK1ZM3szDw7xhe4kBgU5zCiSfFVTuKq3YUV+2lamx1jesg51zU9VNTKsnHysyKqyqcn0yKq3YUV+0ortpL1dgSEZe6a0RE0piSvIhIGku3JD8h2QFUQXHVjuKqHcVVe6kaW9zjSqs+eRER2V26teRFRCSCkryISBprUEnezM4ys9Vm5plZXqV9o8xsrZm9b2b9qnj+wWa22Mw+NLN/mdneCYrzX2a2LPj41MyWVXHcp2a2MjiuOBGxVDrfGDNbHxFb/yqOOzW4jmvNbGQ9xPU3M3vPzFaY2Ytmtn8Vx9XL9drT929mTYKf8drg/dQ+UbFEnLOtmRWaWUnwO3B5lGMKzGxbxM/3hkTHFZy32p+L+f4eXK8VZtalHmL6ZcR1WGZm283sikrH1Nv1MrNJZrbRzFZFbGthZrODfDTbzJpX8dwLg2M+NLMLa31y51yD+QA6AL8E5gF5EdtzgeVAE+Bg4CMgM8rzpwCDgsfjgRH1EPNY4IYq9n0KHFiP128M8P/3cExmcP0OAfYOrmtuguM6BcgKHt8B3JGs61WT7x+4GBgfPB4E/KsefnatgC7B42bAB1HiKgBerq/3U01/LkB/YCZgwHHA4nqOLxP4HH/CUFKuF3A80AVYFbHtTmBk8HhktPc90AL4OPjcPHjcvDbnblAteedciXPu/Si7BgLPOOe+c859AqwFukceYGYGnAg8F2x6HDg9kfEG5/wd8HQizxNn3YG1zrmPnXM7gWfwr2/COOdec86VB18uAtok8nx7UJPvfyD++wf899NJwc86YZxzG5xzS4PHXwIlQOtEnjOOBgJPON8iYH8za1WP5z8J+Mg5F8ts+pg4594EtlTaHPk+qiof9QNmO+e2OOe2ArOBU2tz7gaV5KvRGlgX8XUpP/4FOAD4IiKZRDsm3noDZc65D6vY74DXzGyJmQ1LcCwVLg3+ZZ5Uxb+HNbmWiTQYv9UXTX1cr5p8/z8cE7yftuG/v+pF0D3UGVgcZXcPM1tuZjPN7Mh6CmlPP5dkv6cGUXVDKxnXq0KOc24D+H/EgVCUY2K+dim3kLeZzQFaRtl1vXNuWlVPi7Kt8tjQmhxTYzWM8xyqb8X3dM59ZmYhYLaZvRf8xa+z6uICHgJuxv++b8bvShpc+SWiPDfmcbY1uV5mdj1QDkyu4mXifr2ihRplW0LfS7VhZvsCzwNXOOe2V9q9FL9L4qvgfstU4PB6CGtPP5dkXq+9gQHAqCi7k3W9aiPma5dySd4517cOTysF2kZ83Qb4rNIxm/D/TcwKWl/RjqmxPcVpZlnAmUDXal7js+DzRjN7Eb+rIKakVdPrZ2YTgZej7KrJtYx7XMENpdOAk1zQGRnlNeJ+vaKoyfdfcUxp8HPejx//Kx53ZrYXfoKf7Jx7ofL+yKTvnJthZuPM7EDnXEILcdXg55KQ91QN/QpY6pwrq7wjWdcrQpmZtXLObQi6rzZGOaYU/95BhTb49yRrLF26a6YDg4JRDwfj/zV+O/KAIHEUAr8NNl0IVPWfQTz0Bd5zzpVG22lmTc2sWcVj/JuPq6IdGy+V+kHPqOJ87wCHmz8SaW/8f3WnJziuU4FrgQHOuR1VHFNf16sm3/90/PcP+O+n16v6wxQvQZ//I0CJc+7uKo5pWXFvwMy64/9+b05wXDX5uUwHLghG2RwHbKvopqgHVf43nYzrVUnk+6iqfDQLOMXMmgfdq6cE22quPu4sx+sDPzGVAt8BZcCsiH3X44+KeB/4VcT2GcDPg8eH4Cf/tcCzQJMExvoY8KdK234OzIiIZXnwsRq/2yLR1+9JYCWwIniDtaocV/B1f/zRGx/VU1xr8fsdlwUf4yvHVZ/XK9r3D/wP/h8hgH2C98/a4P10SD1co174/6aviLhO/YE/VbzPgEuDa7Mc/wZ2fj3EFfXnUikuAx4MrudKIkbGJTi2n+In7f0itiXleuH/odkAfB/ksCH493HmAh8Gn1sEx+YBD0c8d3DwXlsL/LG251ZZAxGRNJYu3TUiIhKFkryISBpTkhcRSWNK8iIiaUxJXkQkjSnJi4ikMSV5EZE09n8dJUn9W3RDugAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "weight: 1.9990227 bias: 2.9115517\n" + ] + } + ], + "source": [ + "from IPython import display\n", + "\n", + "step_size = 200\n", + "batch_size = 16\n", + "\n", + "for i in range(step_size):\n", + " data_x,data_y = get_data(batch_size)\n", + " grads = train_network(data_x,data_y) \n", + " optim(grads)\n", + " plot_model_and_datasets(net.weight.default_input, \n", + " net.bias.default_input, data_x, data_y)\n", + " display.clear_output(wait=True)\n", + "\n", + "output = net(eval_x)\n", + "loss_output = criterion(output, eval_label)\n", + "print(\"loss_value:\", loss_output.asnumpy())\n", + "plot_model_and_datasets(net.weight.default_input, net.bias.default_input, data_x,data_y)\n", + "print(\"weight:\", net.weight.default_input[0][0], \"bias:\", net.bias.default_input[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "可以看到最终得到的线性拟合的权重值非常接近目标函数权重weight=2、bias=3。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 总结" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "本次体验我们了解了线性拟合的算法原理,并在MindSpore框架下实现了相应的算法定义,了解了线性拟合这类的线性回归模型在MindSpore中的训练过程,并最终拟合出了一条接近目标函数的模型函数。另外有兴趣的可以调整数据集的生成区间从(-10,10)扩展到(-100,100),看看权重值是否更接近目标函数;调整学习率大小,看看拟合的效率是否有变化;当然也可以探索如何使用MindSpore拟合$f(x)=ax^2+bx+c$这类的二次函数或者更高阶的函数。" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}