Commit 0f6ef8ed authored by minqiyang

Add MNIST

test=develop
Parent a7966e67
@@ -99,7 +99,7 @@ class Conv2D(layers.PyLayer):
         self._bias_param = self._helper.create_parameter(
             attr=self._helper.bias_attr,
-            shape=[num_filter_channels],
+            shape=[num_filters],
             dtype=self._dtype,
             is_bias=True)
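The shape fix above reflects that a convolution produces one output map per filter, so the layer needs exactly one bias value per output filter; sizing the bias by the input channel count was incorrect. A minimal NumPy sketch (shapes are illustrative, not taken from the diff) of how such a per-filter bias is broadcast over the conv output:

import numpy as np

# Hypothetical shapes for illustration only: a batch of feature maps
# produced by a conv layer with 20 filters.
batch, num_filters = 4, 20
feature_map = np.zeros((batch, num_filters, 24, 24))

# One bias value per output filter, broadcast over batch and spatial dims,
# hence shape [num_filters] rather than [num_filter_channels].
bias = np.zeros(num_filters)
out = feature_map + bias.reshape(1, num_filters, 1, 1)
print(out.shape)  # (4, 20, 24, 24)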
@@ -29,8 +29,8 @@ from test_imperative_base import new_program_scope
 class SimpleImgConvPool(fluid.imperative.PyLayer):
     def __init__(self,
                  num_channels,
-                 filter_size,
                  num_filters,
+                 filter_size,
                  pool_size,
                  pool_stride,
                  pool_padding=0,
@@ -77,10 +77,10 @@ class MNIST(fluid.imperative.PyLayer):
         super(MNIST, self).__init__(param_attr=param_attr, bias_attr=bias_attr)
 
         self._simple_img_conv_pool_1 = SimpleImgConvPool(
-            1, 5, 20, 2, 2, act="relu")
+            1, 20, 5, 2, 2, act="relu")
 
         self._simple_img_conv_pool_2 = SimpleImgConvPool(
-            20, 5, 50, 2, 2, act="relu")
+            20, 50, 5, 2, 2, act="relu")
 
         pool_2_shape = 50 * 8 * 8
         SIZE = 10
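The signature change and both call sites are updated together: num_filters now comes before filter_size, so SimpleImgConvPool(1, 20, 5, 2, 2, act="relu") reads as one input channel, 20 filters of size 5x5, then 2x2 pooling with stride 2. A small standalone sketch of that positional order, using a hypothetical helper whose parameters mirror the signature shown in the diff (the real class builds fluid conv and pool ops instead):

# describe_conv_pool is a made-up helper for illustration; its parameter
# order matches the updated SimpleImgConvPool.__init__ above.
def describe_conv_pool(num_channels, num_filters, filter_size, pool_size,
                       pool_stride, pool_padding=0, act=None):
    return ("%d -> %d channels, %dx%d kernels, %dx%d pool, stride %d, act=%s" %
            (num_channels, num_filters, filter_size, filter_size, pool_size,
             pool_size, pool_stride, act))

print(describe_conv_pool(1, 20, 5, 2, 2, act="relu"))   # _simple_img_conv_pool_1
print(describe_conv_pool(20, 50, 5, 2, 2, act="relu"))  # _simple_img_conv_pool_2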
@@ -106,18 +106,15 @@ class TestImperativeMnist(unittest.TestCase):
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed
 
-            mnist = Conv2D(1, 20, 5)
+            # mnist = Conv2D(1, 20, 5)
+            mnist = MNIST()
             sgd = SGDOptimizer(learning_rate=1e-3)
             train_reader = paddle.batch(
                 paddle.dataset.mnist.train(), batch_size=128)
 
-            dy_param_value = {}
-            for param in fluid.default_main_program().global_block(
-            ).all_parameters():
-                dy_param_value[param.name] = param._numpy()
+            dy_param_init_value = {}
             for batch_id, data in enumerate(train_reader()):
-                if batch_id >= 1:
+                if batch_id >= 2:
                     break
 
                 x_data = np.array(
@@ -133,9 +130,17 @@ class TestImperativeMnist(unittest.TestCase):
                 loss = fluid.layers.reduce_mean(cost)
                 dy_out = loss._numpy()
 
+                if batch_id == 0:
+                    for param in fluid.default_main_program().global_block(
+                    ).all_parameters():
+                        dy_param_init_value[param.name] = param._numpy()
 
                 loss._backward()
                 sgd.minimize(loss)
-                dy_filter_param = mnist._filter_param._numpy()
+                dy_param_value = {}
+                for param in fluid.default_main_program().global_block(
+                ).all_parameters():
+                    dy_param_value[param.name] = param._numpy()
 
         with new_program_scope():
             fluid.default_startup_program().random_seed = seed
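With the changes above, the imperative branch records every parameter's value twice: dy_param_init_value is captured at batch_id == 0 before the backward pass, and dy_param_value is captured after sgd.minimize(loss), both as dicts keyed by parameter name. A minimal sketch of that snapshot-and-compare pattern using plain dicts of NumPy arrays (parameter names are illustrative, not taken from the test):

import numpy as np

def snapshot(params):
    """Copy every parameter's current value, keyed by name."""
    return {name: value.copy() for name, value in params.items()}

params = {"conv2d_0.w_0": np.random.rand(20, 1, 5, 5),
          "conv2d_0.b_0": np.zeros(20)}
init_value = snapshot(params)      # like dy_param_init_value at batch_id == 0
params["conv2d_0.b_0"] += 0.01     # stand-in for one SGD update
trained_value = snapshot(params)   # like dy_param_value after sgd.minimize(loss)

print(np.allclose(init_value["conv2d_0.b_0"], trained_value["conv2d_0.b_0"]))  # False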
@@ -143,7 +148,8 @@ class TestImperativeMnist(unittest.TestCase):
             exe = fluid.Executor(fluid.CPUPlace())
 
-            mnist = Conv2D(1, 20, 5)
+            # mnist = Conv2D(1, 20, 5)
+            mnist = MNIST()
             sgd = SGDOptimizer(learning_rate=1e-3)
             train_reader = paddle.batch(
                 paddle.dataset.mnist.train(), batch_size=128)
@@ -156,7 +162,7 @@ class TestImperativeMnist(unittest.TestCase):
             sgd.minimize(loss)
 
             # initialize params and fetch them
-            static_param_value = {}
+            static_param_init_value = {}
             static_param_name_list = []
             for param in fluid.default_startup_program().global_block(
             ).all_parameters():
@@ -166,27 +172,35 @@ class TestImperativeMnist(unittest.TestCase):
                           fetch_list=static_param_name_list)
 
             for i in range(len(static_param_name_list)):
-                static_param_value[static_param_name_list[i]] = out[i]
+                static_param_init_value[static_param_name_list[i]] = out[i]
 
             for batch_id, data in enumerate(train_reader()):
-                if batch_id >= 1:
+                if batch_id >= 2:
                     break
 
                 x_data = np.array(
                     [x[0].reshape(1, 28, 28) for x in data]).astype('float32')
                 y_data = np.array([x[1] for x in data]).astype('int64').reshape(
                     [128, 1])
 
-                static_out, static_filter_param = exe.run(
-                    fluid.default_main_program(),
+                fetch_list = [loss.name]
+                fetch_list.extend(static_param_name_list)
+                out = exe.run(fluid.default_main_program(),
                     feed={"pixel": x_data,
                           "label": y_data},
-                    fetch_list=[loss.name, mnist._filter_param.name])
+                    fetch_list=fetch_list)
 
+                static_param_value = {}
+                static_out = out[0]
+                for i in range(1, len(out)):
+                    static_param_value[static_param_name_list[i - 1]] = out[i]
 
+        for key, value in six.iteritems(static_param_init_value):
+            self.assertTrue(
+                np.allclose(value.all(), dy_param_init_value[key].all()))
+        self.assertTrue(np.allclose(static_out.all(), dy_out.all()))
+        for key, value in six.iteritems(static_param_value):
+            self.assertTrue(np.allclose(value.all(), dy_param_value[key].all()))
-        self.assertTrue(np.allclose(static_out.all(), dy_out.all()))
-        self.assertTrue(
-            np.allclose(static_filter_param.all(), dy_filter_param.all()))
 
 if __name__ == '__main__':
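In the static-graph branch the loss and all parameters are now fetched in a single exe.run call by extending the fetch list, and the closing assertions compare initial values, loss, and post-update values against their imperative counterparts with np.allclose. A minimal sketch (made-up names and values, not taken from the test) of how the flat fetch result is unpacked back into a name-keyed dict and compared:

import numpy as np

# out[0] is the loss, out[1:] line up with the entries of static_param_name_list.
static_param_name_list = ["conv2d_0.w_0", "conv2d_0.b_0"]
out = [np.float32(2.3), np.random.rand(20, 1, 5, 5), np.zeros(20)]

static_out = out[0]
static_param_value = {}
for i in range(1, len(out)):
    static_param_value[static_param_name_list[i - 1]] = out[i]

# The test then checks element-wise closeness against the imperative results;
# here the "imperative" dict is just a copy, so every comparison passes.
dy_param_value = {name: value.copy() for name, value in static_param_value.items()}
for key, value in static_param_value.items():
    assert np.allclose(value, dy_param_value[key])
print(static_out, sorted(static_param_value))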