提交 3920a1df 编写于 作者: L lvmingfu

fix linear regressions code in notebook

上级 b29483a6
...@@ -34,6 +34,44 @@ ...@@ -34,6 +34,44 @@
"5. 执行训练" "5. 执行训练"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 环境准备"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"系统:Ubuntu18.04\n",
"\n",
"MindSpore版本:GPU\n",
"\n",
"设置MindSpore运行配置"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from mindspore import context\n",
"\n",
"context.set_context(mode=context.PYNATIVE_MODE, device_target=\"GPU\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`PYNATIVE_MODE`:自定义调试模式。\n",
"\n",
"`device_target`:设置MindSpore的训练硬件为GPU。"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
...@@ -47,22 +85,18 @@ ...@@ -47,22 +85,18 @@
"source": [ "source": [
"### 定义数据集生成函数\n", "### 定义数据集生成函数\n",
"\n", "\n",
"`get_data`用于生成训练数据集和测试数据集。由于拟合的是线性数据,假定要拟合的目标函数为:$y=2x+3$,那么我们需要的训练数据集应随机分布于函数周边,这里采用了`y=2x+3+noise`的方式生成,其中`noise`为遵循标准正态分布规律的随机数值。" "`get_data`用于生成训练数据集和测试数据集。由于拟合的是线性数据,假定要拟合的目标函数为:$y=2x+3$,那么我们需要的训练数据集应随机分布于函数周边,这里采用了$y=2x+3+noise$的方式生成,其中`noise`为遵循标准正态分布规律的随机数值。"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import numpy as np\n", "import numpy as np\n",
"import mindspore as ms\n", "import mindspore as ms\n",
"from mindspore import Tensor\n", "from mindspore import Tensor\n",
"from mindspore import context\n",
"from mindspore.train import Model\n",
"\n",
"context.set_context(mode=context.PYNATIVE_MODE, device_target=\"GPU\")\n",
" \n", " \n",
"def get_data(num,w=2.0, b=3.0):\n", "def get_data(num,w=2.0, b=3.0):\n",
" np_x = np.ones([num, 1])\n", " np_x = np.ones([num, 1])\n",
...@@ -102,7 +136,7 @@ ...@@ -102,7 +136,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 3,
"metadata": { "metadata": {
"scrolled": true "scrolled": true
}, },
...@@ -157,7 +191,7 @@ ...@@ -157,7 +191,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -192,7 +226,7 @@ ...@@ -192,7 +226,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 5,
"metadata": { "metadata": {
"scrolled": true "scrolled": true
}, },
...@@ -296,7 +330,7 @@ ...@@ -296,7 +330,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 6,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -333,7 +367,7 @@ ...@@ -333,7 +367,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -469,12 +503,12 @@ ...@@ -469,12 +503,12 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"在MindSpore中的所有要编入计算图的类都需要继承`nn.Cell`算子MindSpore的梯度计算函数采用如下方式。" "在MindSpore中的所有要编入计算图的类都需要继承`nn.Cell`算子MindSpore的梯度计算函数采用如下方式。"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 8,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -512,12 +546,12 @@ ...@@ -512,12 +546,12 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"`nn.RMSProp`为完成权重更新的函数,更新方式大致为公式10,但是考虑的因素更多,具体信息请参考官网说明:<www.mindspore.cn/api/zh-CN/master/api/python/mindspore/mindspore.nn.html?highlight=rmsprop#mindspore.nn.RMSProp>" "`nn.RMSProp`为完成权重更新的函数,更新方式大致为公式10,但是考虑的因素更多,具体信息请参考[官网说明](www.mindspore.cn/api/zh-CN/master/api/python/mindspore/mindspore.nn.html?highlight=rmsprop#mindspore.nn.RMSProp)。"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -549,7 +583,7 @@ ...@@ -549,7 +583,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -561,6 +595,7 @@ ...@@ -561,6 +595,7 @@
" plt.scatter(x1,y1,color=\"red\",s=5)\n", " plt.scatter(x1,y1,color=\"red\",s=5)\n",
" plt.scatter(data_x.asnumpy(), data_y.asnumpy(), color=\"black\", s=5)\n", " plt.scatter(data_x.asnumpy(), data_y.asnumpy(), color=\"black\", s=5)\n",
" plt.plot(x, y, \"blue\")\n", " plt.plot(x, y, \"blue\")\n",
" plt.axis([-11, 11, -20, 25])\n",
" plt.show()\n", " plt.show()\n",
" time.sleep(0.02)" " time.sleep(0.02)"
] ]
...@@ -573,7 +608,7 @@ ...@@ -573,7 +608,7 @@
"\n", "\n",
"- `weight`:模型函数的权重,即$w$。\n", "- `weight`:模型函数的权重,即$w$。\n",
"\n", "\n",
"- `bias`:模型函数的权重,$b$。\n", "- `bias`:模型函数的权重,$b$。\n",
"\n", "\n",
"- `data_x`:训练数据的x值。\n", "- `data_x`:训练数据的x值。\n",
"\n", "\n",
...@@ -612,7 +647,7 @@ ...@@ -612,7 +647,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 11,
"metadata": { "metadata": {
"scrolled": true "scrolled": true
}, },
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册