提交 d2e1b46f 编写于 作者: L Luo Tao

update beam_search and seqToseq config, and add ExpActivation api

上级 425e5b0b
...@@ -128,12 +128,16 @@ def gru_encoder_decoder(data_conf, ...@@ -128,12 +128,16 @@ def gru_encoder_decoder(data_conf,
return out return out
decoder_group_name = "decoder_group" decoder_group_name = "decoder_group"
group_inputs=[StaticInput(input=encoded_vector,is_seq=True),
StaticInput(input=encoded_proj,is_seq=True)]
if not is_generating: if not is_generating:
trg_embedding = embedding_layer( trg_embedding = embedding_layer(
input=data_layer(name='target_language_word', input=data_layer(name='target_language_word',
size=target_dict_dim), size=target_dict_dim),
size=word_vector_dim, size=word_vector_dim,
param_attr=ParamAttr(name='_target_language_embedding')) param_attr=ParamAttr(name='_target_language_embedding'))
group_inputs.append(trg_embedding)
# For decoder equipped with attention mechanism, in training, # For decoder equipped with attention mechanism, in training,
# target embeding (the groudtruth) is the data input, # target embeding (the groudtruth) is the data input,
...@@ -142,22 +146,13 @@ def gru_encoder_decoder(data_conf, ...@@ -142,22 +146,13 @@ def gru_encoder_decoder(data_conf,
# for the recurrent_group. # for the recurrent_group.
decoder = recurrent_group(name=decoder_group_name, decoder = recurrent_group(name=decoder_group_name,
step=gru_decoder_with_attention, step=gru_decoder_with_attention,
input=[ input=group_inputs)
StaticInput(input=encoded_vector,
is_seq=True),
StaticInput(input=encoded_proj,
is_seq=True), trg_embedding
])
lbl = data_layer(name='target_language_next_word', lbl = data_layer(name='target_language_next_word',
size=target_dict_dim) size=target_dict_dim)
cost = classification_cost(input=decoder, label=lbl, ) cost = classification_cost(input=decoder, label=lbl)
outputs(cost) outputs(cost)
else: else:
gen_inputs = [StaticInput(input=encoded_vector,
is_seq=True),
StaticInput(input=encoded_proj,
is_seq=True), ]
# In generation, the decoder predicts a next target word based on # In generation, the decoder predicts a next target word based on
# the encoded source sequence and the last generated target word. # the encoded source sequence and the last generated target word.
...@@ -171,10 +166,11 @@ def gru_encoder_decoder(data_conf, ...@@ -171,10 +166,11 @@ def gru_encoder_decoder(data_conf,
size=target_dict_dim, size=target_dict_dim,
embedding_name='_target_language_embedding', embedding_name='_target_language_embedding',
embedding_size=word_vector_dim) embedding_size=word_vector_dim)
gen_inputs.append(trg_embedding) group_inputs.append(trg_embedding)
beam_gen = beam_search(name=decoder_group_name, beam_gen = beam_search(name=decoder_group_name,
step=gru_decoder_with_attention, step=gru_decoder_with_attention,
input=gen_inputs, input=group_inputs,
id_input=data_layer(name="sent_id", id_input=data_layer(name="sent_id",
size=1), size=1),
dict_file=trg_dict_path, dict_file=trg_dict_path,
......
...@@ -12,6 +12,13 @@ AbsActivation ...@@ -12,6 +12,13 @@ AbsActivation
:members: AbsActivation :members: AbsActivation
:noindex: :noindex:
ExpActivation
===============
.. automodule:: paddle.trainer_config_helpers.activations
:members: ExpActivation
:noindex:
IdentityActivation IdentityActivation
================== ==================
......
...@@ -13,96 +13,53 @@ ...@@ -13,96 +13,53 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
#Todo(luotao02) This config is only used for unitest. It is out of date now, and will be updated later.
import math from paddle.trainer_config_helpers import *
beam_search = get_config_arg('beam_search', bool, False) settings(batch_size=15, learning_rate=0)
model_type("recurrent_nn")
Settings(learning_rate=0, batch_size=15, algorithm='sgd')
Inputs("sent_id", "dummy_data_input")
Outputs("predict_word")
num_words = 5 num_words = 5
beam_flag = get_config_arg('beam_search', bool, False)
DataLayer(name="sent_id", size=1, ) sent_id = data_layer(name="sent_id", size=1)
# This layer has no actual use, but only to decide batch_size in generation. # This layer has no actual use, but only to decide batch_size in generation.
# When generating, at least one Memory in RecurrentLayer MUST have a boot layer. # When generating, at least one Memory in RecurrentLayer MUST have a boot layer.
DataLayer(name="dummy_data_input", size=2, ) dummy_data = data_layer(name="dummy_data_input", size=2)
if beam_search: gen_inputs = [StaticInput(input=dummy_data, size=2),
RecurrentLayerGroupBegin("decoding_layer_group", GeneratedInput(size=num_words,
in_links=[], embedding_name="wordvec",
out_links=["predict_word"], embedding_size=num_words)]
generator=Generator(max_num_frames=10,
beam_size=2, def step(dummy_memory, predict_word):
num_results_per_sample=2, ))
else: # simplified RNN for testing
RecurrentLayerGroupBegin("decoding_layer_group", with mixed_layer(size=num_words) as layer:
in_links=[], layer += full_matrix_projection(input=predict_word,
out_links=["predict_word"], param_attr=ParamAttr(name="transtable"))
generator=Generator(max_num_frames=10, ))
dummy_memory = Memory(name="dummy_memory", with mixed_layer(size=num_words, act=ExpActivation()) as out:
size=2, out += trans_full_matrix_projection(input=layer,
boot_layer="dummy_data_input") param_attr=ParamAttr(name="wordvec"))
MixedLayer(name="dummy_memory",
size=2, return out
bias=False,
inputs=[IdentityProjection(dummy_memory)], ) beam_gen = beam_search(name="rnn_gen",
state_memory = Memory(name="state", step=step,
size=num_words, input=gen_inputs,
#boot_bias=True, id_input=sent_id,
#boot_bias_active_type = "tanh", dict_file="./trainer/tests/test_gen_dict.txt",
) result_file="./trainer/tests/dump_text.test",
bos_id=0,
predict_word_memory = Memory(name="predict_word", eos_id=num_words-1,
size=num_words, beam_size=2 if beam_flag else 1,
boot_with_const_id=0, ) num_results_per_sample=2 if beam_flag else 1,
max_length=10)
MixedLayer(
name = "word_embedding", #outputs(beam_gen)
size = num_words, # word embedding dim is the same as num_words in this test. # In this config, as dummy_data_input doesn't work on beam_gen (we can find dummy_memory
bias = False, # is read-only memory, and isn't used by other layers of step), we show the Inputs and Outputs
inputs = TableProjection(predict_word_memory, # as follows. Note that "__beam_search_predict__" is the default output name of beam_search.
initial_std=1, Inputs("sent_id","dummy_data_input")
learning_rate=0, Outputs("__beam_search_predict__")
parameter_name="wordvec"))
Layer( # simplified RNN for testing
name="state",
type="mixed",
size=num_words,
bias=False,
inputs=[FullMatrixProjection("word_embedding",
parameter_name="transtable")])
Layer(name="output",
type="mixed",
size=num_words,
active_type="exponential",
bias=False,
inputs=TransposedFullMatrixProjection("state",
initial_std=1,
learning_rate=0,
parameter_name="wordvec"), )
Layer(name="predict_word", type="maxid", inputs=["output"], )
Layer(name="eos_check",
type="eos_id",
eos_id=num_words - 1,
inputs=["predict_word"], )
RecurrentLayerGroupEnd("decoding_layer_group")
Evaluator(name="answer_printer",
type="seq_text_printer",
dict_file="./trainer/tests/test_gen_dict.txt",
result_file="./trainer/tests/dump_text.test",
inputs=[
"sent_id",
"predict_word",
], )
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
__all__ = ["TanhActivation", "SigmoidActivation", __all__ = ["TanhActivation", "SigmoidActivation",
"SoftmaxActivation", "IdentityActivation", "LinearActivation", "SoftmaxActivation", "IdentityActivation", "LinearActivation",
'SequenceSoftmaxActivation', 'SequenceSoftmaxActivation', 'ExpActivation',
"ReluActivation", "BReluActivation", "SoftReluActivation", "STanhActivation", "ReluActivation", "BReluActivation", "SoftReluActivation", "STanhActivation",
"AbsActivation", "SquareActivation", "BaseActivation"] "AbsActivation", "SquareActivation", "BaseActivation"]
...@@ -185,3 +185,12 @@ class SquareActivation(BaseActivation): ...@@ -185,3 +185,12 @@ class SquareActivation(BaseActivation):
""" """
def __init__(self): BaseActivation.__init__(self, 'square', False) def __init__(self): BaseActivation.__init__(self, 'square', False)
class ExpActivation(BaseActivation):
"""
Exponential Activation.
.. math::
f(z) = e^z.
"""
def __init__(self): BaseActivation.__init__(self, 'exponential', False)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册