提交 d9175ba0 编写于 作者: M minqiyang

Polish conflict

上级 34988bf2
...@@ -145,9 +145,6 @@ def gru_decoder_with_attention(target_embedding, encoder_vec, encoder_proj, ...@@ -145,9 +145,6 @@ def gru_decoder_with_attention(target_embedding, encoder_vec, encoder_proj,
decoder_inputs = fc_1 + fc_2 decoder_inputs = fc_1 + fc_2
h, _, _ = fluid.layers.gru_unit( h, _, _ = fluid.layers.gru_unit(
input=decoder_inputs, hidden=hidden_mem, size=decoder_size * 3) input=decoder_inputs, hidden=hidden_mem, size=decoder_size * 3)
print(decoder_inputs.shape)
print(hidden_mem.shape)
print(decoder_size)
rnn.update_memory(hidden_mem, h) rnn.update_memory(hidden_mem, h)
out = fluid.layers.fc(input=h, out = fluid.layers.fc(input=h,
size=num_classes + 2, size=num_classes + 2,
...@@ -159,8 +156,6 @@ def gru_decoder_with_attention(target_embedding, encoder_vec, encoder_proj, ...@@ -159,8 +156,6 @@ def gru_decoder_with_attention(target_embedding, encoder_vec, encoder_proj,
def attention_train_net(args, data_shape, num_classes): def attention_train_net(args, data_shape, num_classes):
print("xxx")
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label_in = fluid.layers.data( label_in = fluid.layers.data(
name='label_in', shape=[1], dtype='int32', lod_level=1) name='label_in', shape=[1], dtype='int32', lod_level=1)
...@@ -298,10 +293,6 @@ def attention_infer(images, num_classes, use_cudnn=True): ...@@ -298,10 +293,6 @@ def attention_infer(images, num_classes, use_cudnn=True):
input=decoder_inputs, input=decoder_inputs,
hidden=pre_state_expanded, hidden=pre_state_expanded,
size=decoder_size * 3) size=decoder_size * 3)
print(decoder_inputs.shape)
print(pre_state_expanded.shape)
import sys
sys.stdout.flush()
current_state_with_lod = fluid.layers.lod_reset( current_state_with_lod = fluid.layers.lod_reset(
x=current_state, y=pre_score) x=current_state, y=pre_score)
......
...@@ -51,10 +51,6 @@ def train(args): ...@@ -51,10 +51,6 @@ def train(args):
train_net = attention_train_net train_net = attention_train_net
get_feeder_data = get_attention_feeder_data get_feeder_data = get_attention_feeder_data
print("train net")
import sys
sys.stdout.flush()
num_classes = None num_classes = None
num_classes = data_reader.num_classes( num_classes = data_reader.num_classes(
) if num_classes is None else num_classes ) if num_classes is None else num_classes
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册