提交 dbd4d058 编写于 作者: M minqiyang

Add static implementation and fix fc layer

上级 315b133e
......@@ -138,6 +138,13 @@ PYBIND11_MODULE(core, m) {
py::return_value_policy::reference)
.def("value", [](const imperative::VarBase &self) { return self.var_; },
py::return_value_policy::reference)
.def("wait_device",
[](const imperative::VarBase &self) {
platform::DeviceContext *dev_ctx =
platform::DeviceContextPool::Instance().Get(
self.var_->Get<framework::LoDTensor>().place());
dev_ctx->Wait();
})
.def_property(
"desc",
[](const imperative::VarBase &self) { return self.var_desc_; },
......
......@@ -384,6 +384,7 @@ class Variable(object):
self._ivar.stop_gradient = stop_gradient
def _numpy(self):
self._ivar.wait_device()
tensor = self._ivar.value().get_tensor()
return np.array(tensor)
......
......@@ -45,9 +45,9 @@ def guard(device=0):
def to_variable(value, block=None):
assert enabled(), "to_variable could only be called in imperative mode"
if isinstance(value, np.ndarray):
assert enabled(), "to_variable could only be called in imperative mode"
if not block:
block = framework.default_main_program().current_block()
py_var = framework.Variable(
......
......@@ -239,6 +239,17 @@ class FC(layers.Layer):
shape=param_shape,
dtype=self._dtype,
is_bias=False)
print("create param: ", self._w.name, self._w.stop_gradient)
if self._helper.bias_attr:
size = list([self._size])
self._b = self._helper.create_parameter(
attr=self._helper.bias_attr,
shape=size,
dtype=self._dtype,
is_bias=True)
else:
self._b = None
def forward(self, input):
tmp = self._helper.create_variable_for_type_inference(self._dtype)
......@@ -259,8 +270,17 @@ class FC(layers.Layer):
outputs={"Out": pre_bias},
attrs={"use_mkldnn": False})
pre_activation = self._helper.append_bias_op(
pre_bias, dim_start=self._num_flatten_dims)
if self._b:
pre_activation = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
self._helper.append_op(
type='elementwise_add',
inputs={'X': [pre_bias],
'Y': [self._b]},
outputs={'Out': [pre_activation]},
attrs={'axis': self._num_flatten_dims})
else:
pre_activation = pre_bias
return self._helper.append_activation(pre_activation)
......
......@@ -387,6 +387,9 @@ class Optimizer(object):
params_grads = []
for param in parameters:
if param.stop_gradient:
print("parameter:", param.name, "stop gradient, skip it")
continue
# create gradient variable
grad_var = Variable(
block=loss.block,
......
......@@ -31,11 +31,11 @@ train_parameters = {
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"batch_size": 1,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
},
"batch_size": 256,
"batch_size": 1,
"lr": 0.1,
"total_images": 1281164,
}
......@@ -201,6 +201,7 @@ class TestImperativeResnet(unittest.TestCase):
def test_resnet_gpu_float32(self):
seed = 90
batch_size = train_parameters["batch_size"]
with fluid.imperative.guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
......@@ -208,17 +209,21 @@ class TestImperativeResnet(unittest.TestCase):
resnet = ResNet()
optimizer = optimizer_setting(train_parameters)
train_reader = paddle.batch(
paddle.dataset.flowers.train(), batch_size=256)
paddle.dataset.flowers.train(), batch_size=batch_size)
dy_param_init_value = {}
for param in fluid.default_main_program().global_block(
).all_parameters():
dy_param_init_value[param.name] = param._numpy()
for batch_id, data in enumerate(train_reader()):
if batch_id >= 2:
if batch_id >= 1:
break
x_data = np.array(
[x[0].reshape(3, 224, 224) for x in data]).astype('float32')
y_data = np.array([x[1] for x in data]).astype('int64').reshape(
256, 1)
batch_size, 1)
img = to_variable(x_data)
label = to_variable(y_data)
......@@ -232,74 +237,81 @@ class TestImperativeResnet(unittest.TestCase):
if batch_id == 0:
for param in fluid.default_main_program().global_block(
).all_parameters():
dy_param_init_value[param.name] = param._numpy()
if param.name not in dy_param_init_value:
dy_param_init_value[param.name] = param._numpy()
avg_loss._backward()
optimizer.minimize(avg_loss)
dy_param_value = {}
for param in fluid.default_main_program().global_block(
).all_parameters():
dy_param_value[param.name] = param._numpy()
# with new_program_scope():
# fluid.default_startup_program().random_seed = seed
# fluid.default_main_program().random_seed = seed
# exe = fluid.Executor(fluid.CPUPlace())
# # mnist = Conv2D(1, 20, 5)
# mnist = MNIST()
# sgd = SGDOptimizer(learning_rate=1e-3)
# train_reader = paddle.batch(
# paddle.dataset.mnist.train(), batch_size=128)
# img = fluid.layers.data(
# name='pixel', shape=[1, 28, 28], dtype='float32')
# label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# cost = mnist(img)
# loss = fluid.layers.reduce_mean(cost)
# sgd.minimize(loss)
# # initialize params and fetch them
# static_param_init_value = {}
# static_param_name_list = []
# for param in fluid.default_startup_program().global_block(
# ).all_parameters():
# static_param_name_list.append(param.name)
# out = exe.run(fluid.default_startup_program(),
# fetch_list=static_param_name_list)
# for i in range(len(static_param_name_list)):
# static_param_init_value[static_param_name_list[i]] = out[i]
# for batch_id, data in enumerate(train_reader()):
# if batch_id >= 2:
# break
# x_data = np.array(
# [x[0].reshape(1, 28, 28) for x in data]).astype('float32')
# y_data = np.array([x[1] for x in data]).astype('int64').reshape(
# [128, 1])
# fetch_list = [loss.name]
# fetch_list.extend(static_param_name_list)
# out = exe.run(fluid.default_main_program(),
# feed={"pixel": x_data,
# "label": y_data},
# fetch_list=fetch_list)
# static_param_value = {}
# static_out = out[0]
# for i in range(1, len(out)):
# static_param_value[static_param_name_list[i - 1]] = out[i]
# for key, value in six.iteritems(static_param_init_value):
# self.assertTrue(
# np.allclose(value.all(), dy_param_init_value[key].all()))
# self.assertTrue(np.allclose(static_out.all(), dy_out.all()))
# for key, value in six.iteritems(static_param_value):
# self.assertTrue(np.allclose(value.all(), dy_param_value[key].all()))
with new_program_scope():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
exe = fluid.Executor(fluid.CUDAPlace(0))
resnet = ResNet()
optimizer = optimizer_setting(train_parameters)
train_reader = paddle.batch(
paddle.dataset.flowers.train(), batch_size=batch_size)
img = fluid.layers.data(
name='pixel', shape=[3, 224, 224], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
out = resnet(img)
loss = fluid.layers.cross_entropy(input=out, label=label)
avg_loss = fluid.layers.mean(x=loss)
optimizer.minimize(avg_loss)
# initialize params and fetch them
static_param_init_value = {}
static_param_name_list = []
for param in fluid.default_startup_program().global_block(
).all_parameters():
static_param_name_list.append(param.name)
out = exe.run(fluid.default_startup_program(),
fetch_list=static_param_name_list)
for i in range(len(static_param_name_list)):
static_param_init_value[static_param_name_list[i]] = out[i]
for batch_id, data in enumerate(train_reader()):
if batch_id >= 1:
break
x_data = np.array(
[x[0].reshape(3, 224, 224) for x in data]).astype('float32')
y_data = np.array([x[1] for x in data]).astype('int64').reshape(
[batch_size, 1])
fetch_list = [loss.name]
fetch_list.extend(static_param_name_list)
out = exe.run(fluid.default_main_program(),
feed={"pixel": x_data,
"label": y_data},
fetch_list=fetch_list)
static_param_value = {}
static_out = out[0]
for i in range(1, len(out)):
static_param_value[static_param_name_list[i - 1]] = out[i]
self.assertTrue(np.allclose(static_out.all(), dy_out.all()))
for key, value in six.iteritems(static_param_init_value):
self.assertTrue(
np.allclose(value.all(), dy_param_init_value[key].all()))
for key, value in six.iteritems(static_param_value):
if not np.allclose(value.all(), dy_param_value[key].all()):
print(key)
print(value, dy_param_value[key])
self.assertTrue(np.allclose(value.all(), dy_param_value[key].all()))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册