diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc
index 51bbac6d2a1cf2bd64f3e1f2d420e104569273c8..bc39d11ba00a6a7c386162a1f9201c6f992c8692 100644
--- a/paddle/fluid/imperative/tracer.cc
+++ b/paddle/fluid/imperative/tracer.cc
@@ -31,6 +31,7 @@ void CreateGradOp(const framework::OpDesc& op_desc,
   framework::OpInfoMap::Instance()
       .Get(op_desc.Type())
       .GradOpMaker()(op_desc, no_grad_set, grad_to_var, grad_sub_block);
+
   for (auto& desc : descs) {
     grad_op_descs->emplace_back(desc.release());
   }
diff --git a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h
index 216f416de0d1003b944337ee98fb4e6a22c66fc5..2da565f2ae15a50a207173b10d4c350456086582 100644
--- a/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h
+++ b/paddle/fluid/inference/analysis/passes/memory_optimize_pass.h
@@ -13,7 +13,9 @@
 // limitations under the License.
 
 #pragma once
-
+#include <string>
+#include <unordered_map>
+#include <vector>
 #include "paddle/fluid/inference/analysis/analysis_pass.h"
 #include "paddle/fluid/platform/port.h"
 
diff --git a/paddle/fluid/inference/utils/CMakeLists.txt b/paddle/fluid/inference/utils/CMakeLists.txt
index c43eaf7f9849ee4a88ed95bdb8b6966da8760435..a7b239731b9a2e876c16d9ff84dfb8ac3df7b82e 100644
--- a/paddle/fluid/inference/utils/CMakeLists.txt
+++ b/paddle/fluid/inference/utils/CMakeLists.txt
@@ -1,4 +1,4 @@
 cc_library(benchmark SRCS benchmark.cc DEPS enforce)
 cc_test(test_benchmark SRCS benchmark_tester.cc DEPS benchmark)
-cc_binary(visualizer SRCS visualizer.cc DEPS analysis
-          paddle_pass_builder ir_pass_manager pass graph_viz_pass analysis_passes)
+#cc_binary(visualizer SRCS visualizer.cc DEPS analysis
+#          paddle_pass_builder ir_pass_manager pass graph_viz_pass analysis_passes)
diff --git a/python/paddle/fluid/imperative/nn.py b/python/paddle/fluid/imperative/nn.py
index dc90603c3716bef347a5212a3c6fb66522e49c14..6c5961cc63d1c140e0a6f33aac054acdbbe8e8e0 100644
--- a/python/paddle/fluid/imperative/nn.py
+++ b/python/paddle/fluid/imperative/nn.py
@@ -22,13 +22,7 @@ from . import layers
 from ..framework import Variable, OpProtoHolder
 from ..param_attr import ParamAttr
 from ..initializer import Normal, Constant
-
-__all__ = [
-    'Conv2D',
-    'Pool2D',
-    'FC',
-    'BatchNorm',
-]
+__all__ = ['Conv2D', 'Pool2D', 'FC', 'BatchNorm', 'Embedding']
 
 
 class Conv2D(layers.Layer):
@@ -414,3 +408,91 @@ class BatchNorm(layers.Layer):
 
         # Currently, we don't support inplace in imperative mode
         return self._helper.append_activation(batch_norm_out)
+
+
+class Embedding(layers.Layer):
+    """
+    **Embedding Layer**
+
+    This layer is used to look up embeddings of IDs, provided by :attr:`input`, in
+    a lookup table. The result of this lookup is the embedding of each ID in
+    :attr:`input`.
+
+    All the input variables are passed in as local variables to the LayerHelper
+    constructor.
+
+    Args:
+        size(tuple|list): The shape of the lookup table parameter. It should
+            have two elements which indicate the size of the dictionary of
+            embeddings and the size of each embedding vector respectively.
+        is_sparse(bool): The flag indicating whether to use sparse update.
+        is_distributed(bool): Whether to run the lookup table on a remote parameter server.
+        padding_idx(int|long|None): If :attr:`None`, it has no effect on the lookup.
+            Otherwise, the given :attr:`padding_idx` indicates padding the output
+            with zeros whenever the lookup encounters it in :attr:`input`. If
+            :math:`padding_idx < 0`, the :attr:`padding_idx` used in the lookup is
+            :math:`size[0] + padding_idx`.
+        param_attr(ParamAttr): Parameters for this layer.
+        dtype(np.dtype|core.VarDesc.VarType|str): The data type of the parameter: float32, float16, etc.
+
+    Returns:
+        Variable: The tensor variable storing the embeddings of the \
+                  supplied inputs.
+
+    Examples:
+        .. code-block:: python
+
+            dict_size = len(dataset.ids)
+            input = fluid.layers.data(name='ids', shape=[32, 32], dtype='int64')
+            embedding = fluid.imperative.Embedding(size=[dict_size, 16])
+            fc = embedding(input)
+    """
+
+    def __init__(self,
+                 size,
+                 is_sparse=False,
+                 is_distributed=False,
+                 padding_idx=None,
+                 param_attr=None,
+                 dtype='float32'):
+
+        super(Embedding, self).__init__()
+        self._size = size
+        self._is_sparse = is_sparse
+        self._is_distributed = is_distributed
+
+        self._padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
+            size[0] + padding_idx)
+
+        self._param_attr = param_attr
+        self._dtype = dtype
+        self._remote_prefetch = self._is_sparse and (not self._is_distributed)
+        if self._remote_prefetch:
+            assert self._is_sparse is True and self._is_distributed is False
+
+        from ..layer_helper import LayerHelper
+        self._helper = LayerHelper('embedding', param_attr=param_attr)
+        self._w = self._helper.create_parameter(
+            attr=self._param_attr,
+            shape=self._size,
+            dtype=self._dtype,
+            is_bias=False)
+
+    def parameters(self):
+        return [self._w]
+
+    def forward(self, input):
+        out = self._helper.create_variable_for_type_inference(self._dtype)
+        self._helper.append_op(
+            type='lookup_table',
+            inputs={'Ids': input,
+                    'W': self._w},
+            outputs={'Out': out},
+            attrs={
+                'is_sparse': self._is_sparse,
+                'is_distributed': self._is_distributed,
+                'remote_prefetch': self._remote_prefetch,
+                'padding_idx': self._padding_idx
+            })
+
+        return out
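For reviewer orientation, here is a minimal dygraph-style sketch of the new layer. The vocabulary size and ids below are invented for illustration; the import path and `_numpy()` call mirror the ones used in the tests added later in this patch. Note that `lookup_table` expects int64 ids with a trailing dimension of 1:

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.imperative.nn import Embedding

    with fluid.imperative.guard():
        # Three hypothetical ids drawn from a 20-word vocabulary.
        ids = fluid.imperative.base.to_variable(
            np.array([[1], [3], [19]], dtype='int64'))
        embedding = Embedding(size=[20, 16])
        emb = embedding(ids)
        print(emb._numpy().shape)  # (3, 16): one 16-d vector per id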
diff --git a/python/paddle/fluid/tests/unittests/test_imperative.py b/python/paddle/fluid/tests/unittests/test_imperative.py
index adf35c851bf05011223e483e472900a3d415e2ee..baaddf9f2e5b123300f1d083b33ea644665348fd 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative.py
@@ -66,6 +66,128 @@ class MLP(fluid.imperative.Layer):
         return x
 
 
+class SimpleRNNCell(fluid.imperative.Layer):
+    def __init__(self, step_input_size, hidden_size, output_size, param_attr):
+        super(SimpleRNNCell, self).__init__()
+        self.step_input_size = step_input_size
+        self.hidden_size = hidden_size
+        self.output_size = output_size
+        self._dtype = core.VarDesc.VarType.FP32
+        from paddle.fluid.layer_helper import LayerHelper
+        self._helper = LayerHelper(
+            'SimpleRNNCell', act="tanh", param_attr=param_attr)
+
+    def _build_once(self, inputs, pre_hidden):
+        i2h_param_shape = [self.step_input_size, self.hidden_size]
+        h2h_param_shape = [self.hidden_size, self.hidden_size]
+        h2o_param_shape = [self.output_size, self.hidden_size]
+        self._i2h_w = self._helper.create_parameter(
+            attr=self._helper.param_attr,
+            shape=i2h_param_shape,
+            dtype=self._dtype,
+            is_bias=False)
+        self._h2h_w = self._helper.create_parameter(
+            attr=self._helper.param_attr,
+            shape=h2h_param_shape,
+            dtype=self._dtype,
+            is_bias=False)
+        self._h2o_w = self._helper.create_parameter(
+            attr=self._helper.param_attr,
+            shape=h2o_param_shape,
+            dtype=self._dtype,
+            is_bias=False)
+
+    def forward(self, input, pre_hidden):
+
+        tmp_i2h = self._helper.create_variable_for_type_inference(self._dtype)
+        tmp_h2h = self._helper.create_variable_for_type_inference(self._dtype)
+        hidden = self._helper.create_variable_for_type_inference(self._dtype)
+        out = self._helper.create_variable_for_type_inference(self._dtype)
+        softmax_out = self._helper.create_variable_for_type_inference(
+            self._dtype)
+        reduce_out = self._helper.create_variable_for_type_inference(
+            self._dtype)
+        self._helper.append_op(
+            type="mul",
+            inputs={"X": input,
+                    "Y": self._i2h_w},
+            outputs={"Out": tmp_i2h},
+            attrs={"x_num_col_dims": 1,
+                   "y_num_col_dims": 1})
+
+        self._helper.append_op(
+            type="mul",
+            inputs={"X": pre_hidden,
+                    "Y": self._h2h_w},
+            outputs={"Out": tmp_h2h},
+            attrs={"x_num_col_dims": 1,
+                   "y_num_col_dims": 1})
+
+        self._helper.append_op(
+            type="elementwise_add",
+            inputs={'X': tmp_h2h,
+                    'Y': tmp_i2h},
+            outputs={'Out': hidden},
+            attrs={'axis': -1,
+                   'use_mkldnn': False})
+        hidden = self._helper.append_activation(hidden)
+
+        self._helper.append_op(
+            type="mul",
+            inputs={"X": hidden,
+                    "Y": self._h2o_w},
+            outputs={"Out": out},
+            attrs={"x_num_col_dims": 1,
+                   "y_num_col_dims": 1})
+
+        self._helper.append_op(
+            type="softmax",
+            inputs={"X": out},
+            outputs={"Out": softmax_out},
+            attrs={"use_cudnn": False})
+
+        self._helper.append_op(
+            type='reduce_sum',
+            inputs={'X': softmax_out},
+            outputs={'Out': reduce_out},
+            attrs={'dim': None,
+                   'keep_dim': False,
+                   'reduce_all': True})
+
+        return reduce_out, hidden
+
+
+class SimpleRNN(fluid.imperative.Layer):
+    def __init__(self):
+        super(SimpleRNN, self).__init__()
+        self.seq_len = 4
+        self._cell = SimpleRNNCell(
+            3,
+            3,
+            3,
+            fluid.ParamAttr(initializer=fluid.initializer.Constant(value=0.1)))
+
+    def forward(self, inputs):
+        outs = list()
+        pre_hiddens = list()
+
+        init_hidden = fluid.layers.tensor.create_parameter(
+            attr=fluid.ParamAttr(
+                initializer=fluid.initializer.Constant(value=0.1)),
+            shape=[1, 3],
+            dtype='float32',
+            is_bias=False)
+        pre_hidden = init_hidden
+        for i in range(self.seq_len):
+            input = fluid.layers.slice(
+                inputs, axes=[1], starts=[i], ends=[i + 1])
+            input = fluid.layers.reshape(input, shape=[1, 3])
+            out_softmax, pre_hidden = self._cell(input, pre_hidden)
+            outs.append(out_softmax)
+
+        return outs, pre_hiddens
+
+
 class TestImperative(unittest.TestCase):
     def test_sum_op(self):
         x = np.ones([2, 2], np.float32)
@@ -211,6 +333,41 @@ class TestImperative(unittest.TestCase):
         self.assertTrue(np.allclose(dy_out, static_out))
         self.assertTrue(np.allclose(dy_grad, static_grad))
 
+    def test_rnn(self):
+        np_inp = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],
+                           [10.0, 11.0, 12.0]])
+        np_inp = np_inp.reshape((1, 4, 3))
+        np_inp = np_inp.astype(np.float32)
+        with fluid.imperative.guard():
+            var_inp = fluid.imperative.base.to_variable(np_inp)
+            var_inp = fluid.layers.reshape(var_inp, shape=[1, 4, 3])
+            simple_rnn = SimpleRNN()
+            outs, pre_hiddens = simple_rnn.forward(var_inp)
+            dy_out = outs[3]._numpy()
+            outs[3]._backward()
+            dy_grad_h2o = simple_rnn._cell._h2o_w._gradient()
+            dy_grad_h2h = simple_rnn._cell._h2h_w._gradient()
+            dy_grad_i2h = simple_rnn._cell._i2h_w._gradient()
+
+        with new_program_scope():
+            inp = fluid.layers.data(
+                name="inp", shape=[1, 4, 3], append_batch_size=False)
+            simple_rnn = SimpleRNN()
+            outs, pre_hiddens = simple_rnn(inp)
+            param_grads = fluid.backward.append_backward(outs[3])
+            exe = fluid.Executor(fluid.CPUPlace())
+            exe.run(fluid.default_startup_program())
+            static_out, static_grad_h2o, static_grad_h2h, static_grad_i2h = exe.run(
+                feed={inp.name: np_inp},
+                fetch_list=[
+                    outs[3].name, param_grads[0][1].name,
+                    param_grads[1][1].name, param_grads[2][1].name
+                ])
+        self.assertTrue(np.allclose(dy_out, static_out))
+        self.assertTrue(np.allclose(dy_grad_h2o, static_grad_h2o))
+        self.assertTrue(np.allclose(dy_grad_h2h, static_grad_h2h))
+        self.assertTrue(np.allclose(dy_grad_i2h, static_grad_i2h))
+
 
 if __name__ == '__main__':
     unittest.main()
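Per step, the op-by-op graph assembled in SimpleRNNCell.forward (two mul ops, an elementwise_add with a tanh activation, a mul projection, softmax, and an all-elements reduce_sum) boils down to the following arithmetic. This is a numpy sketch under the test's square 3x3 weight shapes; the function and argument names are ours, not part of the patch:

    import numpy as np

    def simple_rnn_step(x, h_prev, w_i2h, w_h2h, w_h2o):
        # mul + mul + elementwise_add, then the "tanh" activation
        hidden = np.tanh(x.dot(w_i2h) + h_prev.dot(w_h2h))
        # mul projection, softmax, then reduce_sum with reduce_all=True
        logits = hidden.dot(w_h2o)
        softmax = np.exp(logits) / np.exp(logits).sum()
        return softmax.sum(), hidden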
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5877e91f92e642e69265104c6728cd9bd41c41cd
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_imperative_ptb_rnn.py
@@ -0,0 +1,352 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import paddle.fluid as fluid
+from paddle.fluid.imperative.nn import Embedding
+import paddle.fluid.framework as framework
+from paddle.fluid.optimizer import SGDOptimizer
+from paddle.fluid.imperative.base import to_variable
+from test_imperative_base import new_program_scope
+import numpy as np
+import six
+from paddle.fluid.backward import append_backward
+
+
+class SimpleLSTMRNN(fluid.imperative.Layer):
+    def __init__(self,
+                 hidden_size,
+                 num_steps,
+                 num_layers=2,
+                 init_scale=0.1,
+                 dropout=None):
+        super(SimpleLSTMRNN, self).__init__()
+        self._hidden_size = hidden_size
+        self._num_layers = num_layers
+        self._init_scale = init_scale
+        self._dropout = dropout
+        self._input = None
+        self._num_steps = num_steps
+
+    def _build_once(self, input_embedding, init_hidden=None, init_cell=None):
+        self.weight_1_arr = []
+        self.weight_2_arr = []
+        self.bias_arr = []
+        self.hidden_array = []
+        self.cell_array = []
+        self.mask_array = []
+
+        for i in range(self._num_layers):
+            weight_1 = fluid.layers.create_parameter(
+                shape=[self._hidden_size * 2, self._hidden_size * 4],
+                dtype="float32",
+                name="fc_weight1_" + str(i),
+                default_initializer=fluid.initializer.UniformInitializer(
+                    low=-self._init_scale, high=self._init_scale))
+            self.weight_1_arr.append(weight_1)
+            bias_1 = fluid.layers.create_parameter(
+                [self._hidden_size * 4],
+                dtype="float32",
+                name="fc_bias1_" + str(i),
+                default_initializer=fluid.initializer.Constant(0.0))
+            self.bias_arr.append(bias_1)
+
+            pre_hidden = fluid.layers.slice(
+                init_hidden, axes=[0], starts=[i], ends=[i + 1])
+            pre_cell = fluid.layers.slice(
+                init_cell, axes=[0], starts=[i], ends=[i + 1])
+            pre_hidden = fluid.layers.reshape(
+                pre_hidden, shape=[-1, self._hidden_size])
+            pre_cell = fluid.layers.reshape(
+                pre_cell, shape=[-1, self._hidden_size])
+            self.hidden_array.append(pre_hidden)
+            self.cell_array.append(pre_cell)
+
+    def parameters(self):
+        parameters = list()
+        for param in self.weight_1_arr:
+            parameters.append(param)
+        for param in self.weight_2_arr:
+            parameters.append(param)
+        for bias in self.bias_arr:
+            parameters.append(bias)
+        return parameters
+
+    def forward(self, input_embedding, init_hidden=None, init_cell=None):
+        res = []
+        for index in range(self._num_steps):
+            self._input = fluid.layers.slice(
+                input_embedding, axes=[1], starts=[index], ends=[index + 1])
+            self._input = fluid.layers.reshape(
+                self._input, shape=[-1, self._hidden_size])
+            for k in range(self._num_layers):
+                pre_hidden = self.hidden_array[k]
+                pre_cell = self.cell_array[k]
+                weight_1 = self.weight_1_arr[k]
+                bias = self.bias_arr[k]
+
+                nn = fluid.layers.concat([self._input, pre_hidden], 1)
+                gate_input = fluid.layers.matmul(x=nn, y=weight_1)
+
+                gate_input = fluid.layers.elementwise_add(gate_input, bias)
+                i, j, f, o = fluid.layers.split(
+                    gate_input, num_or_sections=4, dim=-1)
+                c = pre_cell * fluid.layers.sigmoid(f) + fluid.layers.sigmoid(
+                    i) * fluid.layers.tanh(j)
+                m = fluid.layers.tanh(c) * fluid.layers.sigmoid(o)
+                self.hidden_array[k] = m
+                self.cell_array[k] = c
+                self._input = m
+
+                if self._dropout is not None and self._dropout > 0.0:
+                    self._input = fluid.layers.dropout(
+                        self._input,
+                        dropout_prob=self._dropout,
+                        dropout_implementation='upscale_in_train')
+            res.append(
+                fluid.layers.reshape(
+                    self._input, shape=[1, -1, self._hidden_size]))
+        real_res = fluid.layers.concat(res, 0)
+        real_res = fluid.layers.transpose(x=real_res, perm=[1, 0, 2])
+        last_hidden = fluid.layers.concat(self.hidden_array, 1)
+        last_hidden = fluid.layers.reshape(
+            last_hidden, shape=[-1, self._num_layers, self._hidden_size])
+        last_hidden = fluid.layers.transpose(x=last_hidden, perm=[1, 0, 2])
+        last_cell = fluid.layers.concat(self.cell_array, 1)
+        last_cell = fluid.layers.reshape(
+            last_cell, shape=[-1, self._num_layers, self._hidden_size])
+        last_cell = fluid.layers.transpose(x=last_cell, perm=[1, 0, 2])
+        return real_res, last_hidden, last_cell
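The inner loop above is a standard LSTM cell: one fused matmul over the concatenated input and previous hidden state produces all four gates, which split() then slices into i, j, f, o. A numpy transcription for one layer step, where w is [2*hidden, 4*hidden] and b is [4*hidden] (the helper names are ours, for illustration only):

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def lstm_step(x, h_prev, c_prev, w, b):
        gate_input = np.concatenate([x, h_prev], axis=1).dot(w) + b
        i, j, f, o = np.split(gate_input, 4, axis=1)
        c = c_prev * sigmoid(f) + sigmoid(i) * np.tanh(j)  # new cell state
        h = np.tanh(c) * sigmoid(o)                        # new hidden state
        return h, c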
+
+
+class PtbModel(fluid.imperative.Layer):
+    def __init__(self,
+                 hidden_size,
+                 vocab_size,
+                 num_layers=2,
+                 num_steps=20,
+                 init_scale=0.1,
+                 dropout=None):
+        super(PtbModel, self).__init__()
+        self.hidden_size = hidden_size
+        self.vocab_size = vocab_size
+        self.init_scale = init_scale
+        self.num_layers = num_layers
+        self.num_steps = num_steps
+        self.dropout = dropout
+        self.simple_lstm_rnn = SimpleLSTMRNN(
+            hidden_size,
+            num_steps,
+            num_layers=num_layers,
+            init_scale=init_scale,
+            dropout=dropout)
+        self.embedding = Embedding(
+            size=[vocab_size, hidden_size],
+            dtype='float32',
+            is_sparse=False,
+            param_attr=fluid.ParamAttr(
+                name='embedding_para',
+                initializer=fluid.initializer.UniformInitializer(
+                    low=-init_scale, high=init_scale)))
+        self.softmax_weight = fluid.layers.create_parameter(
+            [self.hidden_size, self.vocab_size],
+            dtype="float32",
+            name="softmax_weight",
+            default_initializer=fluid.initializer.UniformInitializer(
+                low=-self.init_scale, high=self.init_scale))
+        self.softmax_bias = fluid.layers.create_parameter(
+            [self.vocab_size],
+            dtype="float32",
+            name='softmax_bias',
+            default_initializer=fluid.initializer.UniformInitializer(
+                low=-self.init_scale, high=self.init_scale))
+
+    def _build_once(self, input, label, init_hidden, init_cell):
+        pass
+
+    def parameters(self):
+        parameters = self.simple_lstm_rnn.parameters() + [
+            self.softmax_weight, self.softmax_bias
+        ] + self.embedding.parameters()
+        return parameters
+
+    def forward(self, input, label, init_hidden, init_cell):
+
+        init_h = fluid.layers.reshape(
+            init_hidden, shape=[self.num_layers, -1, self.hidden_size])
+
+        init_c = fluid.layers.reshape(
+            init_cell, shape=[self.num_layers, -1, self.hidden_size])
+
+        x_emb = self.embedding(input)
+        x_emb = fluid.layers.reshape(
+            x_emb, shape=[-1, self.num_steps, self.hidden_size])
+        if self.dropout is not None and self.dropout > 0.0:
+            x_emb = fluid.layers.dropout(
+                x_emb,
+                dropout_prob=self.dropout,
+                dropout_implementation='upscale_in_train')
+        rnn_out, last_hidden, last_cell = self.simple_lstm_rnn(x_emb, init_h,
+                                                               init_c)
+        rnn_out = fluid.layers.reshape(
+            rnn_out, shape=[-1, self.num_steps, self.hidden_size])
+        projection = fluid.layers.matmul(rnn_out, self.softmax_weight)
+        projection = fluid.layers.elementwise_add(projection, self.softmax_bias)
+        projection = fluid.layers.reshape(
+            projection, shape=[-1, self.vocab_size])
+        loss = fluid.layers.softmax_with_cross_entropy(
+            logits=projection, label=label, soft_label=False)
+        loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
+        loss = fluid.layers.reduce_mean(loss, dim=[0])
+        loss = fluid.layers.reduce_sum(loss)
+        loss.permissions = True
+
+        return loss, last_hidden, last_cell
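Shape-wise, PtbModel.forward reduces the RNN output to a scalar loss as follows. A numpy walk-through using the test's sizes below (batch 4, 3 steps, hidden 10, vocab 1000); the per-token cross-entropy values are stubbed with zeros since only the shapes matter here:

    import numpy as np

    batch_size, num_steps, hidden, vocab = 4, 3, 10, 1000
    rnn_out = np.zeros((batch_size * num_steps, hidden), 'float32')
    softmax_w = np.zeros((hidden, vocab), 'float32')

    projection = rnn_out.dot(softmax_w)          # [batch*steps, vocab]
    per_token = np.zeros((batch_size * num_steps, 1), 'float32')
    loss = per_token.reshape(-1, num_steps)      # [batch, steps]
    loss = loss.mean(axis=0).sum()               # reduce_mean(dim=[0]) + reduce_sum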
+
+
+class TestImperativePtbRnn(unittest.TestCase):
+    def test_ptb_rnn_cpu_float32(self):
+        seed = 90
+        hidden_size = 10
+        vocab_size = 1000
+        num_layers = 1
+        num_steps = 3
+        init_scale = 0.1
+        batch_size = 4
+
+        with fluid.imperative.guard():
+            fluid.default_startup_program().random_seed = seed
+            fluid.default_main_program().random_seed = seed
+            # TODO: marsyang1993 Change seed to
+            ptb_model = PtbModel(
+                hidden_size=hidden_size,
+                vocab_size=vocab_size,
+                num_layers=num_layers,
+                num_steps=num_steps,
+                init_scale=init_scale)
+
+            sgd = SGDOptimizer(learning_rate=1e-3)
+            dy_param_updated = dict()
+            dy_param_init = dict()
+            dy_loss = None
+            last_hidden = None
+            last_cell = None
+            for i in range(2):
+                x_data = np.arange(12).reshape(4, 3).astype('int64')
+                y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
+                x_data = x_data.reshape((-1, num_steps, 1))
+                y_data = y_data.reshape((-1, 1))
+                init_hidden_data = np.zeros(
+                    (num_layers, batch_size, hidden_size), dtype='float32')
+                init_cell_data = np.zeros(
+                    (num_layers, batch_size, hidden_size), dtype='float32')
+                x = to_variable(x_data)
+                y = to_variable(y_data)
+                init_hidden = to_variable(init_hidden_data)
+                init_cell = to_variable(init_cell_data)
+                dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
+                                                            init_cell)
+                if i == 0:
+                    for param in ptb_model.parameters():
+                        dy_param_init[param.name] = param._numpy()
+                dy_loss._backward()
+                sgd.minimize(dy_loss)
+                for param in ptb_model.parameters():
+                    dy_param_updated[param.name] = param._numpy()
+                # print("dy_loss is {}".format(dy_loss._numpy()))
+                # print("last_hidden is {}".format(last_hidden._numpy()))
+                # print("last_cell is {}".format(last_cell._numpy()))
+
+        with new_program_scope():
+            fluid.default_startup_program().random_seed = seed
+            fluid.default_main_program().random_seed = seed
+            # TODO: marsyang1993 Change seed to
+            ptb_model = PtbModel(
+                hidden_size=hidden_size,
+                vocab_size=vocab_size,
+                num_layers=num_layers,
+                num_steps=num_steps,
+                init_scale=init_scale)
+
+            exe = fluid.Executor(fluid.CPUPlace())
+            sgd = SGDOptimizer(learning_rate=1e-3)
+            x = fluid.layers.data(name="x", shape=[-1, 3, 1], dtype='int64')
+            y = fluid.layers.data(name="y", shape=[-1, 1], dtype='int64')
+            init_hidden = fluid.layers.data(
+                name="init_hidden", shape=[1], dtype='float32')
+            init_cell = fluid.layers.data(
+                name="init_cell", shape=[1], dtype='float32')
+
+            static_loss, static_last_hidden, static_last_cell = ptb_model(
+                x, y, init_hidden, init_cell)
+            sgd.minimize(static_loss)
+            static_param_updated = dict()
+            static_param_init = dict()
+            static_param_name_list = list()
+            for param in ptb_model.parameters():
+                static_param_name_list.append(param.name)
+
+            out = exe.run(framework.default_startup_program(),
+                          fetch_list=static_param_name_list)
+            for i in range(len(static_param_name_list)):
+                static_param_init[static_param_name_list[i]] = out[i]
+            static_loss_value = None
+            static_last_cell_value = None
+            static_last_hidden_value = None
+            for i in range(2):
+                x_data = np.arange(12).reshape(4, 3).astype('int64')
+                y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
+                x_data = x_data.reshape((-1, num_steps, 1))
+                y_data = y_data.reshape((-1, 1))
+                init_hidden_data = np.zeros(
+                    (num_layers, batch_size, hidden_size), dtype='float32')
+                init_cell_data = np.zeros(
+                    (num_layers, batch_size, hidden_size), dtype='float32')
+                fetch_list = [static_loss, static_last_hidden, static_last_cell]
+                fetch_list.extend(static_param_name_list)
+                out = exe.run(fluid.default_main_program(),
+                              feed={
+                                  "x": x_data,
+                                  "y": y_data,
+                                  "init_hidden": init_hidden_data,
+                                  "init_cell": init_cell_data
+                              },
+                              fetch_list=fetch_list)
+                static_loss_value = out[0]
+                static_last_hidden_value = out[1]
+                static_last_cell_value = out[2]
+                for k in range(3, len(out)):
+                    static_param_updated[static_param_name_list[k - 3]] = out[k]
+
+        self.assertTrue(
+            np.allclose(static_loss_value, dy_loss._numpy()))
+        self.assertTrue(
+            np.allclose(static_last_cell_value,
+                        last_cell._numpy()))
+        self.assertTrue(
+            np.allclose(static_last_hidden_value,
+                        last_hidden._numpy()))
+        for key, value in six.iteritems(static_param_init):
+            self.assertTrue(
+                np.allclose(value, dy_param_init[key]))
+        for key, value in six.iteritems(static_param_updated):
+            self.assertTrue(
+                np.allclose(value, dy_param_updated[key]))
+
+
+if __name__ == '__main__':
+    unittest.main()