未验证 提交 a7c242a6 编写于 作者: J jzhang533 提交者: GitHub

cnn based image classification added, minor changes to dygraph example (#877)

上级 5952d7bc
因为 它太大了无法显示 source diff 。你可以改为 查看blob
......@@ -30,7 +30,7 @@
"output_type": "stream",
"text": [
"0.0.0\n",
"edf5f3173a25ae2230e9619ab5426317b4bd7cde\n"
"7f2aa2db3c69cb9ebb8bae9e19280e75f964e1d0\n"
]
}
],
......@@ -62,16 +62,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[[-1.4396907 -0.59741247]\n",
" [ 1.4717171 -0.06998838]\n",
" [-1.2790705 0.4278928 ]\n",
" [ 1.1862146 -1.895377 ]]\n",
"[[ 0.40741017 0.2083312 ]\n",
" [-1.7567089 0.72117436]\n",
" [ 0.8870686 -1.1389219 ]\n",
" [ 1.1233491 0.34348443]]\n",
"[1. 2.]\n",
"[[-0.4396907 1.4025875 ]\n",
" [ 2.4717171 1.9300116 ]\n",
" [-0.2790705 2.4278927 ]\n",
" [ 2.1862144 0.10462296]]\n",
"[-2.6345158 1.3317404 -0.4232849 -2.6045394]\n"
"[[ 1.4074101 2.208331 ]\n",
" [-0.75670886 2.7211742 ]\n",
" [ 1.8870686 0.86107814]\n",
" [ 2.1233492 2.3434844 ]]\n",
"[ 0.8240726 -0.31436014 -1.3907751 1.810318 ]\n"
]
}
],
......@@ -108,13 +108,13 @@
"output_type": "stream",
"text": [
"0 +> [5 6 7]\n",
"1 -> [-3 -3 -3]\n",
"2 -> [-3 -1 3]\n",
"3 +> [ 5 13 33]\n",
"4 -> [-3 11 75]\n",
"1 +> [5 7 9]\n",
"2 +> [ 5 9 15]\n",
"3 -> [-3 3 21]\n",
"4 +> [ 5 21 87]\n",
"5 +> [ 5 37 249]\n",
"6 -> [ -3 59 723]\n",
"7 -> [ -3 123 2181]\n",
"7 +> [ 5 133 2193]\n",
"8 -> [ -3 251 6555]\n",
"9 -> [ -3 507 19677]\n"
]
......@@ -179,21 +179,21 @@
"name": "stdout",
"output_type": "stream",
"text": [
"0 [87.28088]\n",
"200 [57.775795]\n",
"400 [42.70884]\n",
"600 [45.509155]\n",
"800 [29.966158]\n",
"1000 [11.885025]\n",
"1200 [16.888378]\n",
"1400 [3.5780585]\n",
"1600 [5.3149533]\n",
"1800 [4.501356]\n",
"2000 [3.022315]\n",
"2200 [1.7214009]\n",
"2400 [0.3694626]\n",
"2600 [0.31249344]\n",
"2800 [0.1450614]\n"
"0 [2.0915627]\n",
"200 [0.67530334]\n",
"400 [0.52042854]\n",
"600 [0.28010666]\n",
"800 [0.09739777]\n",
"1000 [0.09307177]\n",
"1200 [0.04252927]\n",
"1400 [0.03095707]\n",
"1600 [0.03022156]\n",
"1800 [0.01616007]\n",
"2000 [0.01069116]\n",
"2200 [0.0055158]\n",
"2400 [0.00195092]\n",
"2600 [0.00101116]\n",
"2800 [0.00192219]\n"
]
}
],
......@@ -205,9 +205,9 @@
"\n",
"model = MyModel(input_size, hidden_size)\n",
"\n",
"loss_fn = paddle.nn.MSELoss(reduction='sum')\n",
"optimizer = paddle.optimizer.SGD(learning_rate=0.0001, \n",
" parameter_list=model.parameters())\n",
"loss_fn = paddle.nn.MSELoss(reduction='mean')\n",
"optimizer = paddle.optimizer.SGD(learning_rate=0.01, \n",
" parameters=model.parameters())\n",
"\n",
"for t in range(200 * (total_data // batch_size)):\n",
" idx = np.random.choice(total_data, batch_size, replace=False)\n",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册