Commit cdba41af authored by Youwei Song, committed by hong

dygraph Embedding layer use lookuptable v2 (#21209)

* dygraph Embedding layer use lookuptable v2
test=develop

* fix test_nce
test=develop
Parent 122b37ce
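Summary note: in dygraph mode, Embedding now runs on the lookup_table_v2 op, so ids are plain int64 tensors (no trailing dimension of size 1) and the output shape is the input shape with emb_size appended. A minimal sketch of the new behavior, mirroring the docstring example changed below (the exact Embedding constructor arguments vary across 1.x releases, so treat this as illustrative):

```python
import numpy as np
import paddle.fluid as fluid

inp_word = np.array([[2, 3, 5], [4, 2, 1]]).astype('int64')  # shape [2, 3], no [.., 1]

with fluid.dygraph.guard():
    emb = fluid.dygraph.Embedding(
        size=[20, 32], param_attr='emb.w', is_sparse=False)
    rlt = emb(fluid.dygraph.to_variable(inp_word))
    # lookup_table_v2 appends emb_size: [2, 3] -> [2, 3, 32]
    assert list(rlt.shape) == [2, 3, 32]
```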
@@ -71,8 +71,7 @@ class LookupTableV2OpMaker : public framework::OpProtoAndCheckerMaker {
              "which is a learnable parameter.");
     AddInput("Ids",
              "An input with type int64 "
-             "contains the ids to be looked up in W. "
-             "The last dimension size must be 1.");
+             "contains the ids to be looked up in W.");
     AddOutput("Out", "The lookup results, which have the same type as W.");
     AddAttr<bool>("is_sparse",
                   "(boolean, default false) "
...
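The removed sentence above was the v1 contract. For migrating callers, the only practical difference between lookup_table and lookup_table_v2 is this Ids shape contract; a numpy sketch of the shapes (plain fancy indexing stands in for the op itself):

```python
import numpy as np

vocab, emb_size = 128, 16
W = np.random.rand(vocab, emb_size).astype('float32')

ids_v1 = np.array([[[1], [3]], [[2], [4]]])  # lookup_table: last dim must be 1
ids_v2 = ids_v1.squeeze(-1)                  # lookup_table_v2: plain id tensor

# both lookups produce the same values; only the id layout differs
out = W[ids_v2]  # shape [2, 2, 16] == ids_v2.shape + [emb_size]
assert out.shape == ids_v2.shape + (emb_size,)
```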
@@ -1361,11 +1361,10 @@ class Embedding(layers.Layer):
     It automatically constructs a 2D embedding matrix based on the
     input :attr:`size` (vocab_size, emb_size) and :attr:`dtype` .
-    This layer requires the last dimension of Tensor shape must be equal to 1. The shape
-    of output Tensor is generated by replacing the last dimension of the input Tensor shape
-    with emb_size.
+    The shape of output Tensor is generated by appending an emb_size dimension to the
+    last dimension of the input Tensor shape.

-    The id in :attr:`input` must satisfy :math:`0 =< id < size[0]` ,
+    **Note:** The id in :attr:`input` must satisfy :math:`0 =< id < size[0]` ,
     otherwise the program will throw an exception and exit.

     .. code-block:: text

@@ -1373,8 +1372,8 @@ class Embedding(layers.Layer):
         Case 1:

         input is a Tensor. padding_idx = -1
-            input.data = [[[1], [3]], [[2], [4]], [[4], [127]]]
-            input.shape = [3, 2, 1]
+            input.data = [[1, 3], [2, 4], [4, 127]]
+            input.shape = [3, 2]
         Given size = [128, 16]
         output is a Tensor:
             out.shape = [3, 2, 16]
@@ -1431,7 +1430,8 @@ class Embedding(layers.Layer):
         import numpy as np

         # example 1
-        inp_word = np.array([[[1]]]).astype('int64')
+        inp_word = np.array([[2, 3, 5], [4, 2, 1]]).astype('int64')
+        inp_word.shape  # [2, 3]
         dict_size = 20
         with fluid.dygraph.guard():
             emb = fluid.dygraph.Embedding(
@@ -1440,6 +1440,7 @@ class Embedding(layers.Layer):
                 param_attr='emb.w',
                 is_sparse=False)
             static_rlt3 = emb(base.to_variable(inp_word))
+            static_rlt3.shape  # [2, 3, 32]

         # example 2: load custom or pre-trained word vectors
         weight_data = np.random.random(size=(128, 100))  # word vectors with numpy format
@@ -1495,7 +1496,7 @@ class Embedding(layers.Layer):
     def forward(self, input):
         out = self._helper.create_variable_for_type_inference(self._dtype)
         self._helper.append_op(
-            type='lookup_table',
+            type='lookup_table_v2',
             inputs={'Ids': input,
                     'W': self._w},
             outputs={'Out': out},
@@ -1883,7 +1884,7 @@ class NCE(layers.Layer):
         window_size = 5
         dict_size = 20
         label_word = int(window_size // 2) + 1
-        inp_word = np.array([[[1]], [[2]], [[3]], [[4]], [[5]]]).astype('int64')
+        inp_word = np.array([[1], [2], [3], [4], [5]]).astype('int64')
         nid_freq_arr = np.random.dirichlet(np.ones(20) * 1000).astype('float32')
         with fluid.dygraph.guard():
@@ -1915,7 +1916,8 @@ class NCE(layers.Layer):
                 param_attr='nce.w',
                 bias_attr='nce.b')
-            nce_loss3 = nce(embs3, words[label_word])
+            wl = fluid.layers.unsqueeze(words[label_word], axes=[0])
+            nce_loss3 = nce(embs3, wl)
     """
...
@@ -395,7 +395,7 @@ class OCRAttention(fluid.dygraph.Layer):
         backward_first = fluid.layers.reshape(
             backward_first, [-1, backward_first.shape[2]], inplace=False)
         decoder_boot = self.fc(backward_first)
-        label_in = fluid.layers.reshape(label_in, [-1, 1], inplace=False)
+        label_in = fluid.layers.reshape(label_in, [-1], inplace=False)
         trg_embedding = self.embedding(label_in)
         trg_embedding = fluid.layers.reshape(
...
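The OCR model follows the same migration recipe: flatten the ids to rank 1, let the embedding append the channel dimension, then reshape back. A sketch with illustrative shapes (the table size and batch here are made up, not the OCR model's):

```python
import numpy as np

emb_dim = 8
table = np.random.rand(100, emb_dim).astype('float32')  # stand-in embedding table

label_in = np.array([[3, 7], [1, 4]]).astype('int64')   # e.g. [batch, seq]
flat = label_in.reshape(-1)                             # was reshape to [-1, 1] pre-PR
trg_embedding = table[flat]                             # [batch * seq, emb_dim]
trg_embedding = trg_embedding.reshape(label_in.shape + (emb_dim,))
assert trg_embedding.shape == (2, 2, emb_dim)
```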
@@ -254,7 +254,6 @@ class TestDygraphPtbRnn(unittest.TestCase):
             for i in range(batch_num):
                 x_data = np.arange(12).reshape(4, 3).astype('int64')
                 y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
-                x_data = x_data.reshape((-1, num_steps, 1))
                 y_data = y_data.reshape((-1, 1))
                 init_hidden_data = np.zeros(
                     (num_layers, batch_size, hidden_size), dtype='float32')
@@ -313,7 +312,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
             ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
             sgd = SGDOptimizer(learning_rate=1e-3)
             x = fluid.layers.data(
-                name="x", shape=[-1, num_steps, 1], dtype='int64')
+                name="x", shape=[-1, num_steps], dtype='int64')
             y = fluid.layers.data(name="y", shape=[-1, 1], dtype='float32')
             init_hidden = fluid.layers.data(
                 name="init_hidden", shape=[1], dtype='float32')
...
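The PTB tests simply drop the reshape that added a trailing singleton to the ids, and declare the static `x` layer without it. A minimal numpy sketch of the feed shape before and after, with num_steps = 3 as the test's reshape implies:

```python
import numpy as np

num_steps = 3
x_data = np.arange(12).reshape(4, 3).astype('int64')

# before this PR the test added a trailing singleton:
#   x_data = x_data.reshape((-1, num_steps, 1))  -> shape (4, 3, 1)
# now the (4, 3) array is fed as-is, matching
#   fluid.layers.data(name="x", shape=[-1, num_steps], dtype='int64')
assert x_data.shape == (4, num_steps)
```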
@@ -246,7 +246,6 @@ class TestDygraphPtbRnn(unittest.TestCase):
             for i in range(batch_num):
                 x_data = np.arange(12).reshape(4, 3).astype('int64')
                 y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
-                x_data = x_data.reshape((-1, num_steps, 1))
                 y_data = y_data.reshape((-1, 1))
                 init_hidden_data = np.zeros(
                     (num_layers, batch_size, hidden_size), dtype='float32')
@@ -328,7 +327,6 @@ class TestDygraphPtbRnn(unittest.TestCase):
             for i in range(batch_num):
                 x_data = np.arange(12).reshape(4, 3).astype('int64')
                 y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
-                x_data = x_data.reshape((-1, num_steps, 1))
                 y_data = y_data.reshape((-1, 1))
                 init_hidden_data = np.zeros(
                     (num_layers, batch_size, hidden_size), dtype='float32')
@@ -433,7 +431,6 @@ class TestDygraphPtbRnn(unittest.TestCase):
             for i in range(batch_num):
                 x_data = np.arange(12).reshape(4, 3).astype('int64')
                 y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
-                x_data = x_data.reshape((-1, num_steps, 1))
                 y_data = y_data.reshape((-1, 1))
                 init_hidden_data = np.zeros(
                     (num_layers, batch_size, hidden_size), dtype='float32')
@@ -537,7 +534,6 @@ class TestDygraphPtbRnn(unittest.TestCase):
             for i in range(batch_num):
                 x_data = np.arange(12).reshape(4, 3).astype('int64')
                 y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
-                x_data = x_data.reshape((-1, num_steps, 1))
                 y_data = y_data.reshape((-1, 1))
                 init_hidden_data = np.zeros(
                     (num_layers, batch_size, hidden_size), dtype='float32')
@@ -652,7 +648,6 @@ class TestDygraphPtbRnn(unittest.TestCase):
             for i in range(1):
                 x_data = np.arange(12).reshape(4, 3).astype('int64')
                 y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
-                x_data = x_data.reshape((-1, num_steps, 1))
                 y_data = y_data.reshape((-1, 1))
                 init_hidden_data = np.zeros(
                     (num_layers, batch_size, hidden_size), dtype='float32')
@@ -745,7 +740,6 @@ class TestDygraphPtbRnn(unittest.TestCase):
             for i in range(1):
                 x_data = np.arange(12).reshape(4, 3).astype('int64')
                 y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
-                x_data = x_data.reshape((-1, num_steps, 1))
                 y_data = y_data.reshape((-1, 1))
                 init_hidden_data = np.zeros(
                     (num_layers, batch_size, hidden_size), dtype='float32')
@@ -846,7 +840,6 @@ class TestDygraphPtbRnn(unittest.TestCase):
             for i in range(1):
                 x_data = np.arange(12).reshape(4, 3).astype('int64')
                 y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
-                x_data = x_data.reshape((-1, num_steps, 1))
                 y_data = y_data.reshape((-1, 1))
                 init_hidden_data = np.zeros(
                     (num_layers, batch_size, hidden_size), dtype='float32')
...
@@ -229,11 +229,11 @@ seq_len = ModelHyperParams.max_length
 # compile time.
 input_descs = {
     # The actual data shape of src_word is:
-    # [batch_size, max_src_len_in_batch, 1]
-    "src_word": [(batch_size, seq_len, 1), "int64", 2],
+    # [batch_size, max_src_len_in_batch]
+    "src_word": [(batch_size, seq_len), "int64", 2],
     # The actual data shape of src_pos is:
-    # [batch_size, max_src_len_in_batch, 1]
-    "src_pos": [(batch_size, seq_len, 1), "int64"],
+    # [batch_size, max_src_len_in_batch]
+    "src_pos": [(batch_size, seq_len), "int64"],
     # This input is used to remove attention weights on paddings in the
     # encoder.
     # The actual data shape of src_slf_attn_bias is:
@@ -241,12 +241,12 @@ input_descs = {
     "src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
                            seq_len), "float32"],
     # The actual data shape of trg_word is:
-    # [batch_size, max_trg_len_in_batch, 1]
-    "trg_word": [(batch_size, seq_len, 1), "int64",
+    # [batch_size, max_trg_len_in_batch]
+    "trg_word": [(batch_size, seq_len), "int64",
                  2],  # lod_level is only used in fast decoder.
     # The actual data shape of trg_pos is:
-    # [batch_size, max_trg_len_in_batch, 1]
-    "trg_pos": [(batch_size, seq_len, 1), "int64"],
+    # [batch_size, max_trg_len_in_batch]
+    "trg_pos": [(batch_size, seq_len), "int64"],
     # This input is used to remove attention weights on paddings and
     # subsequent words in the decoder.
     # The actual data shape of trg_slf_attn_bias is:
@@ -317,17 +317,17 @@ batch_num = 5
 np.random.seed = 90
 src_word_np = np.arange(1, TrainTaskConfig.batch_size * seq_len + 1).reshape(
-    [TrainTaskConfig.batch_size, seq_len, 1]).astype('int64')
+    [TrainTaskConfig.batch_size, seq_len]).astype('int64')
 src_pos_np = np.random.randint(
-    1, seq_len, size=(TrainTaskConfig.batch_size, seq_len, 1), dtype='int64')
+    1, seq_len, size=(TrainTaskConfig.batch_size, seq_len), dtype='int64')
 src_slf_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size,
                                        ModelHyperParams.n_head, seq_len,
                                        seq_len).astype('float32')
 trg_word_np = np.arange(1, TrainTaskConfig.batch_size * seq_len + 1).reshape(
-    [TrainTaskConfig.batch_size, seq_len, 1]).astype('int64')
+    [TrainTaskConfig.batch_size, seq_len]).astype('int64')
 trg_pos_np = np.random.randint(
-    1, seq_len, size=(TrainTaskConfig.batch_size, seq_len, 1), dtype='int64')
+    1, seq_len, size=(TrainTaskConfig.batch_size, seq_len), dtype='int64')
 trg_slf_attn_bias_np = np.random.randn(TrainTaskConfig.batch_size,
                                        ModelHyperParams.n_head, seq_len,
                                        seq_len).astype('float32')
...
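The transformer test data follows the same convention: id arrays become rank 2, matching the updated input_descs. A sketch of the new arrays with small stand-in sizes (batch_size = 2 and seq_len = 4 are illustrative, not the test's configuration):

```python
import numpy as np

# stand-ins for TrainTaskConfig.batch_size / ModelHyperParams.max_length
batch_size, seq_len = 2, 4

src_word_np = np.arange(1, batch_size * seq_len + 1).reshape(
    [batch_size, seq_len]).astype('int64')  # was [batch_size, seq_len, 1]
src_pos_np = np.random.randint(
    1, seq_len, size=(batch_size, seq_len), dtype='int64')
assert src_word_np.shape == src_pos_np.shape == (batch_size, seq_len)
```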
@@ -842,7 +842,7 @@ class TestLayer(LayerTest):
         window_size = 5
         dict_size = 20
         label_word = int(window_size // 2) + 1
-        inp_word = np.array([[[1]], [[2]], [[3]], [[4]], [[5]]]).astype('int64')
+        inp_word = np.array([[1], [2], [3], [4], [5]]).astype('int64')
         nid_freq_arr = np.random.dirichlet(np.ones(20) * 1000).astype('float32')
         seed = 1
         with self.static_graph():
@@ -850,7 +850,7 @@ class TestLayer(LayerTest):
             for i in range(window_size):
                 words.append(
                     layers.data(
-                        name='word_{0}'.format(i), shape=[1], dtype='int64'))
+                        name='word_{0}'.format(i), shape=[None], dtype='int64'))
             sample_weights = layers.fill_constant(
                 shape=[5, 1], dtype='float32', value=1)
             embs = []
@@ -858,7 +858,7 @@ class TestLayer(LayerTest):
                 if i == label_word:
                     continue
-                emb = layers.embedding(
+                emb = fluid.embedding(
                     input=words[i],
                     size=[dict_size, 32],
                     param_attr='emb.w',
@@ -866,8 +866,9 @@ class TestLayer(LayerTest):
                 embs.append(emb)
             embs = layers.concat(input=embs, axis=1)
+            wl = fluid.layers.unsqueeze(words[label_word], axes=[0])
             nce_loss = layers.nce(input=embs,
-                                  label=words[label_word],
+                                  label=wl,
                                   num_total_classes=dict_size,
                                   num_neg_samples=2,
                                   sampler="custom_dist",
@@ -886,7 +887,7 @@ class TestLayer(LayerTest):
             for i in range(window_size):
                 words.append(
                     layers.data(
-                        name='word_{0}'.format(i), shape=[1], dtype='int64'))
+                        name='word_{0}'.format(i), shape=[None], dtype='int64'))
             sample_weights = layers.fill_constant(
                 shape=[5, 1], dtype='float32', value=1)
             emb = nn.Embedding(
@@ -914,7 +915,8 @@ class TestLayer(LayerTest):
                 bias_attr='nce.b',
                 sample_weight=sample_weights)
-            nce_loss2 = nce(embs2, words[label_word])
+            wl = fluid.layers.unsqueeze(words[label_word], axes=[0])
+            nce_loss2 = nce(embs2, wl)
             feed_dict = dict()
             for i in range(len(words)):
                 feed_dict['word_{0}'.format(i)] = inp_word[i]
@@ -953,7 +955,8 @@ class TestLayer(LayerTest):
                 bias_attr='nce.b',
                 sample_weight=sample_weights)
-            dy_rlt = nce(embs3, words[label_word])
+            wl = fluid.layers.unsqueeze(words[label_word], axes=[0])
+            dy_rlt = nce(embs3, wl)
             dy_rlt_value = dy_rlt.numpy()
         self.assertTrue(np.allclose(static_rlt2, static_rlt))
@@ -1004,14 +1007,15 @@ class TestLayer(LayerTest):
                 bias_attr='nce2.b',
                 sample_weight=sample_weights)
-            nce1_loss = nce1(embs3, words[label_word])
-            nce2_loss = nce2(embs3, words[label_word])
+            wl = fluid.layers.unsqueeze(words[label_word], axes=[0])
+            nce1_loss = nce1(embs3, wl)
+            nce2_loss = nce2(embs3, wl)
             self.assertFalse(
                 np.array_equal(nce1_loss.numpy(), nce2_loss.numpy()))
             nce2.weight.set_value(nce1.weight.numpy())
             nce2.bias.set_value(nce1.bias)
-            nce1_loss = nce1(embs3, words[label_word])
-            nce2_loss = nce2(embs3, words[label_word])
+            nce1_loss = nce1(embs3, wl)
+            nce2_loss = nce2(embs3, wl)
             self.assertTrue(
                 np.array_equal(nce1_loss.numpy(), nce2_loss.numpy()))
...
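Note the static-graph path above also switches from layers.embedding to fluid.embedding. My reading (an assumption, not stated in the diff): fluid.embedding is the lookup_table_v2-backed entry point that takes plain id tensors, while fluid.layers.embedding keeps the old last-dim-1 contract for compatibility. A hedged sketch of the new call, mirroring the test:

```python
import paddle.fluid as fluid

# shape=[None] as in the updated test; ids need no trailing [.., 1] dimension
word = fluid.layers.data(name='word_0', shape=[None], dtype='int64')
emb = fluid.embedding(input=word, size=[20, 32], param_attr='emb.w')
```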
@@ -240,7 +240,7 @@ class TestSaveLoadBase(unittest.TestCase):
             exe = fluid.Executor(place)
             sgd = Adam(learning_rate=1e-3)
             x = fluid.layers.data(
-                name="x", shape=[-1, num_steps, 1], dtype='int64')
+                name="x", shape=[-1, num_steps], dtype='int64')
             y = fluid.layers.data(name="y", shape=[-1, 1], dtype='float32')
             init_hidden = fluid.layers.data(
                 name="init_hidden", shape=[1], dtype='float32')
@@ -341,7 +341,7 @@ class TestSaveLoadPartial(unittest.TestCase):
             exe = fluid.Executor(place)
             sgd = Adam(learning_rate=1e-3)
             x = fluid.layers.data(
-                name="x", shape=[-1, num_steps, 1], dtype='int64')
+                name="x", shape=[-1, num_steps], dtype='int64')
             y = fluid.layers.data(name="y", shape=[-1, 1], dtype='float32')
             init_hidden = fluid.layers.data(
                 name="init_hidden", shape=[1], dtype='float32')
@@ -451,7 +451,7 @@ class TestSaveLoadSetStateDict(unittest.TestCase):
             exe = fluid.Executor(place)
             sgd = Adam(learning_rate=1e-3)
             x = fluid.layers.data(
-                name="x", shape=[-1, num_steps, 1], dtype='int64')
+                name="x", shape=[-1, num_steps], dtype='int64')
             y = fluid.layers.data(name="y", shape=[-1, 1], dtype='float32')
             init_hidden = fluid.layers.data(
                 name="init_hidden", shape=[1], dtype='float32')
@@ -552,7 +552,7 @@ class TestProgramStatePartial(unittest.TestCase):
             exe = fluid.Executor(place)
             sgd = Adam(learning_rate=1e-3)
             x = fluid.layers.data(
-                name="x", shape=[-1, num_steps, 1], dtype='int64')
+                name="x", shape=[-1, num_steps], dtype='int64')
             y = fluid.layers.data(name="y", shape=[-1, 1], dtype='float32')
             init_hidden = fluid.layers.data(
                 name="init_hidden", shape=[1], dtype='float32')
...