提交 fff67a94 编写于 作者: J JiabinYang

test=develop, use parameters() to get parameters

上级 2e309b11
......@@ -483,6 +483,9 @@ class Embedding(layers.Layer):
dtype=self._dtype,
is_bias=False)
def parameters(self):
return [self._w]
def forward(self, input):
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
......
......@@ -75,6 +75,16 @@ class SimpleLSTMRNN(fluid.imperative.Layer):
self.hidden_array.append(pre_hidden)
self.cell_array.append(pre_cell)
def parameters(self):
parameters = list()
for param in self.weight_1_arr:
parameters.append(param)
for param in self.weight_2_arr:
parameters.append(param)
for bias in self.bias_arr:
parameters.append(bias)
return parameters
def forward(self, input_embedding, init_hidden=None, init_cell=None):
res = []
for index in range(self._num_steps):
......@@ -167,6 +177,12 @@ class PtbModel(fluid.imperative.Layer):
def _build_once(self, input, label, init_hidden, init_cell):
pass
def parameters(self):
parameters = self.simple_lstm_rnn.parameters() + [
self.softmax_weight, self.softmax_bias
] + self.embedding.parameters()
return parameters
def forward(self, input, label, init_hidden, init_cell):
init_h = fluid.layers.reshape(
......@@ -246,13 +262,11 @@ class TestImperativePtbRnn(unittest.TestCase):
dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
init_cell)
if i == 0:
for param in fluid.default_main_program().global_block(
).all_parameters():
for param in ptb_model.parameters():
dy_param_init[param.name] = param._numpy()
dy_loss._backward()
sgd.minimize(dy_loss)
for param in fluid.default_main_program().global_block(
).all_parameters():
for param in ptb_model.parameters():
dy_param_updated[param.name] = param._numpy()
# print("dy_loss is {}".format(dy_loss._numpy()))
# print("last_hidden is {}".format(last_hidden._numpy()))
......@@ -284,8 +298,7 @@ class TestImperativePtbRnn(unittest.TestCase):
static_param_updated = dict()
static_param_init = dict()
static_param_name_list = list()
for param in fluid.default_startup_program().global_block(
).all_parameters():
for param in ptb_model.parameters():
static_param_name_list.append(param.name)
out = exe.run(framework.default_startup_program(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册