未验证 提交 a7c242a6 编写于 作者: J jzhang533 提交者: GitHub

cnn based image classification added, minor changes to dygraph example (#877)

上级 5952d7bc
因为 它太大了无法显示 source diff 。你可以改为 查看blob
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"0.0.0\n", "0.0.0\n",
"edf5f3173a25ae2230e9619ab5426317b4bd7cde\n" "7f2aa2db3c69cb9ebb8bae9e19280e75f964e1d0\n"
] ]
} }
], ],
...@@ -62,16 +62,16 @@ ...@@ -62,16 +62,16 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"[[-1.4396907 -0.59741247]\n", "[[ 0.40741017 0.2083312 ]\n",
" [ 1.4717171 -0.06998838]\n", " [-1.7567089 0.72117436]\n",
" [-1.2790705 0.4278928 ]\n", " [ 0.8870686 -1.1389219 ]\n",
" [ 1.1862146 -1.895377 ]]\n", " [ 1.1233491 0.34348443]]\n",
"[1. 2.]\n", "[1. 2.]\n",
"[[-0.4396907 1.4025875 ]\n", "[[ 1.4074101 2.208331 ]\n",
" [ 2.4717171 1.9300116 ]\n", " [-0.75670886 2.7211742 ]\n",
" [-0.2790705 2.4278927 ]\n", " [ 1.8870686 0.86107814]\n",
" [ 2.1862144 0.10462296]]\n", " [ 2.1233492 2.3434844 ]]\n",
"[-2.6345158 1.3317404 -0.4232849 -2.6045394]\n" "[ 0.8240726 -0.31436014 -1.3907751 1.810318 ]\n"
] ]
} }
], ],
...@@ -108,13 +108,13 @@ ...@@ -108,13 +108,13 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"0 +> [5 6 7]\n", "0 +> [5 6 7]\n",
"1 -> [-3 -3 -3]\n", "1 +> [5 7 9]\n",
"2 -> [-3 -1 3]\n", "2 +> [ 5 9 15]\n",
"3 +> [ 5 13 33]\n", "3 -> [-3 3 21]\n",
"4 -> [-3 11 75]\n", "4 +> [ 5 21 87]\n",
"5 +> [ 5 37 249]\n", "5 +> [ 5 37 249]\n",
"6 -> [ -3 59 723]\n", "6 -> [ -3 59 723]\n",
"7 -> [ -3 123 2181]\n", "7 +> [ 5 133 2193]\n",
"8 -> [ -3 251 6555]\n", "8 -> [ -3 251 6555]\n",
"9 -> [ -3 507 19677]\n" "9 -> [ -3 507 19677]\n"
] ]
...@@ -179,21 +179,21 @@ ...@@ -179,21 +179,21 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"0 [87.28088]\n", "0 [2.0915627]\n",
"200 [57.775795]\n", "200 [0.67530334]\n",
"400 [42.70884]\n", "400 [0.52042854]\n",
"600 [45.509155]\n", "600 [0.28010666]\n",
"800 [29.966158]\n", "800 [0.09739777]\n",
"1000 [11.885025]\n", "1000 [0.09307177]\n",
"1200 [16.888378]\n", "1200 [0.04252927]\n",
"1400 [3.5780585]\n", "1400 [0.03095707]\n",
"1600 [5.3149533]\n", "1600 [0.03022156]\n",
"1800 [4.501356]\n", "1800 [0.01616007]\n",
"2000 [3.022315]\n", "2000 [0.01069116]\n",
"2200 [1.7214009]\n", "2200 [0.0055158]\n",
"2400 [0.3694626]\n", "2400 [0.00195092]\n",
"2600 [0.31249344]\n", "2600 [0.00101116]\n",
"2800 [0.1450614]\n" "2800 [0.00192219]\n"
] ]
} }
], ],
...@@ -205,9 +205,9 @@ ...@@ -205,9 +205,9 @@
"\n", "\n",
"model = MyModel(input_size, hidden_size)\n", "model = MyModel(input_size, hidden_size)\n",
"\n", "\n",
"loss_fn = paddle.nn.MSELoss(reduction='sum')\n", "loss_fn = paddle.nn.MSELoss(reduction='mean')\n",
"optimizer = paddle.optimizer.SGD(learning_rate=0.0001, \n", "optimizer = paddle.optimizer.SGD(learning_rate=0.01, \n",
" parameter_list=model.parameters())\n", " parameters=model.parameters())\n",
"\n", "\n",
"for t in range(200 * (total_data // batch_size)):\n", "for t in range(200 * (total_data // batch_size)):\n",
" idx = np.random.choice(total_data, batch_size, replace=False)\n", " idx = np.random.choice(total_data, batch_size, replace=False)\n",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册