Unverified commit fe86771a, authored by liu zhengxi, committed by GitHub

[Migrate Fluid] Migrate Decoder, BeamSearchDecoder (#48754)

Parent a5999d83
This diff is collapsed.
@@ -151,23 +151,15 @@ class Decoder:
         if self.decoding_strategy == "beam_search":
             beam_size = kwargs.get("beam_size", 4)
-            encoder_output = (
-                layers.BeamSearchDecoder.tile_beam_merge_with_batch(
-                    encoder_output, beam_size
-                )
-            )
+            encoder_output = BeamSearchDecoder.tile_beam_merge_with_batch(
+                encoder_output, beam_size
+            )
-            encoder_padding_mask = (
-                layers.BeamSearchDecoder.tile_beam_merge_with_batch(
-                    encoder_padding_mask, beam_size
-                )
-            )
+            encoder_padding_mask = BeamSearchDecoder.tile_beam_merge_with_batch(
+                encoder_padding_mask, beam_size
+            )
-            decoder = layers.BeamSearchDecoder(
+            decoder = BeamSearchDecoder(
                 cell=self.decoder_cell, output_fn=output_layer, **kwargs
             )
-        else:
-            decoder = layers.BasicDecoder(
-                self.decoder_cell, helper, output_fn=output_layer
-            )
         (
             decoder_output,
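For reference, after this migration the test builds the decoder from `paddle.nn.BeamSearchDecoder` directly instead of reaching it through `fluid.layers`. The sketch below shows the plain `paddle.nn` beam-search flow outside the test harness; the layer sizes, token ids, and `max_step_num` are illustrative assumptions, not values from this test.

import paddle
from paddle.nn import BeamSearchDecoder, dynamic_decode
from paddle.nn import Embedding, GRUCell, Linear

# Illustrative sizes only: vocab 100, hidden 32, batch 4, source length 8.
trg_embeder = Embedding(100, 32)
output_layer = Linear(32, 32)
decoder_cell = GRUCell(input_size=32, hidden_size=32)
decoder = BeamSearchDecoder(
    decoder_cell,
    start_token=0,      # assumed <bos> id
    end_token=1,        # assumed <eos> id
    beam_size=4,
    embedding_fn=trg_embeder,
    output_fn=output_layer,
)
encoder_output = paddle.ones((4, 8, 32), dtype=paddle.get_default_dtype())
# If per-step attention needs the encoder states, tile them to the beam layout first:
# encoder_output = BeamSearchDecoder.tile_beam_merge_with_batch(encoder_output, 4)
outputs = dynamic_decode(
    decoder=decoder,
    inits=decoder_cell.get_initial_states(encoder_output),
    max_step_num=10,
)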
@@ -535,130 +527,6 @@ class TestDynamicDecode(unittest.TestCase):
         )
         self.exe = Executor(place)
 
-    def test_mle_train(self):
-        paddle.enable_static()
-        self.model_hparams["decoding_strategy"] = "train_greedy"
-        agent = SeqPGAgent(
-            model_cls=Seq2SeqModel,
-            alg_cls=MLE,
-            model_hparams=self.model_hparams,
-            alg_hparams={"lr": 0.001},
-            executor=self.exe,
-            main_program=fluid.Program(),
-            startup_program=fluid.Program(),
-            seed=123,
-        )
-        self.exe.run(agent.startup_program)
-        for iter_idx in range(self.iter_num):
-            reward, cost = agent.learn(
-                {
-                    "src": self.data["src"][
-                        iter_idx
-                        * self.batch_size : (iter_idx + 1)
-                        * self.batch_size,
-                        :,
-                    ],
-                    "src_sequence_length": self.data["src_sequence_length"][
-                        iter_idx
-                        * self.batch_size : (iter_idx + 1)
-                        * self.batch_size
-                    ],
-                    "trg": self.data["trg"][
-                        iter_idx
-                        * self.batch_size : (iter_idx + 1)
-                        * self.batch_size,
-                        :,
-                    ],
-                    "trg_sequence_length": self.data["trg_sequence_length"][
-                        iter_idx
-                        * self.batch_size : (iter_idx + 1)
-                        * self.batch_size
-                    ],
-                    "label": self.data["label"][
-                        iter_idx
-                        * self.batch_size : (iter_idx + 1)
-                        * self.batch_size
-                    ],
-                },
-                fetch_list=[agent.cost, agent.cost],
-            )
-            print(
-                "iter_idx: %d, reward: %f, cost: %f"
-                % (iter_idx, reward.mean(), cost)
-            )
-
-    def test_greedy_train(self):
-        paddle.enable_static()
-        self.model_hparams["decoding_strategy"] = "infer_greedy"
-        agent = SeqPGAgent(
-            model_cls=Seq2SeqModel,
-            alg_cls=PolicyGradient,
-            model_hparams=self.model_hparams,
-            alg_hparams={"lr": 0.001},
-            executor=self.exe,
-            main_program=fluid.Program(),
-            startup_program=fluid.Program(),
-            seed=123,
-        )
-        self.exe.run(agent.startup_program)
-        for iter_idx in range(self.iter_num):
-            reward, cost = agent.learn(
-                {
-                    "src": self.data["src"][
-                        iter_idx
-                        * self.batch_size : (iter_idx + 1)
-                        * self.batch_size,
-                        :,
-                    ],
-                    "src_sequence_length": self.data["src_sequence_length"][
-                        iter_idx
-                        * self.batch_size : (iter_idx + 1)
-                        * self.batch_size
-                    ],
-                },
-                fetch_list=[agent.reward, agent.cost],
-            )
-            print(
-                "iter_idx: %d, reward: %f, cost: %f"
-                % (iter_idx, reward.mean(), cost)
-            )
-
-    def test_sample_train(self):
-        paddle.enable_static()
-        self.model_hparams["decoding_strategy"] = "infer_sample"
-        agent = SeqPGAgent(
-            model_cls=Seq2SeqModel,
-            alg_cls=PolicyGradient,
-            model_hparams=self.model_hparams,
-            alg_hparams={"lr": 0.001},
-            executor=self.exe,
-            main_program=fluid.Program(),
-            startup_program=fluid.Program(),
-            seed=123,
-        )
-        self.exe.run(agent.startup_program)
-        for iter_idx in range(self.iter_num):
-            reward, cost = agent.learn(
-                {
-                    "src": self.data["src"][
-                        iter_idx
-                        * self.batch_size : (iter_idx + 1)
-                        * self.batch_size,
-                        :,
-                    ],
-                    "src_sequence_length": self.data["src_sequence_length"][
-                        iter_idx
-                        * self.batch_size : (iter_idx + 1)
-                        * self.batch_size
-                    ],
-                },
-                fetch_list=[agent.reward, agent.cost],
-            )
-            print(
-                "iter_idx: %d, reward: %f, cost: %f"
-                % (iter_idx, reward.mean(), cost)
-            )
-
     def test_beam_search_infer(self):
         paddle.set_default_dtype("float32")
         paddle.enable_static()
@@ -693,19 +561,6 @@ class TestDynamicDecode(unittest.TestCase):
             fetch_list=[output],
         )[0]
-
-    def func_dynamic_basic_decoder(self):
-        paddle.disable_static()
-        src = paddle.to_tensor(np.random.randint(8, size=(8, 4)))
-        src_length = paddle.to_tensor(np.random.randint(8, size=(8)))
-        model = Seq2SeqModel(**self.model_hparams)
-        probs, samples, sample_length = model(src, src_length)
-        paddle.enable_static()
-
-    def test_dynamic_basic_decoder(self):
-        with _test_eager_guard():
-            self.func_dynamic_basic_decoder()
-        self.func_dynamic_basic_decoder()
 
 
 class ModuleApiTest(unittest.TestCase):
     @classmethod
......
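The removed `train_greedy`, `infer_greedy`, and `infer_sample` tests exercised decoding strategies built on the fluid `BasicDecoder` and its embedding helpers, which this PR does not migrate. For reference, a greedy or sampling decode loop can be written directly against `paddle.nn` cells; the sketch below is only an illustration under assumed sizes and token ids, not code taken from the test.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

# Illustrative sizes only.
vocab_size, hidden_size, batch_size, max_len = 100, 32, 4, 10

embedder = nn.Embedding(vocab_size, hidden_size)
cell = nn.GRUCell(input_size=hidden_size, hidden_size=hidden_size)
output_layer = nn.Linear(hidden_size, vocab_size)

token = paddle.full([batch_size], 0, dtype="int64")    # assumed <bos> id
state = paddle.zeros([batch_size, hidden_size])        # zero initial state

steps = []
for _ in range(max_len):
    step_out, state = cell(embedder(token), state)     # one decoder step
    logits = output_layer(step_out)
    token = paddle.argmax(logits, axis=-1)             # greedy choice
    # sampling variant: token = paddle.multinomial(F.softmax(logits, axis=-1), 1).squeeze(-1)
    steps.append(token)

samples = paddle.stack(steps, axis=1)                  # [batch_size, max_len]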
This diff is collapsed.