# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import tempfile
import time
import unittest

import numpy as np
import transformer_util as util
from transformer_dygraph_model import (
    CrossEntropyCriterion,
    Transformer,
    position_encoding_init,
)

import paddle
from paddle import fluid

trainer_count = 1
place = (
    fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
)
SEED = 10
STEP_NUM = 10


def train_static(args, batch_generator):
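    """Train the Transformer with the static graph: build the program once,
    run it with an Executor, and return the average losses collected every
    `print_step` steps as a numpy array for comparison with the dygraph run.
    """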
    paddle.enable_static()
    paddle.seed(SEED)
    paddle.framework.random._manual_program_seed(SEED)
    train_prog = fluid.Program()
    startup_prog = fluid.Program()

    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():
            # define input and reader
            input_field_names = (
                util.encoder_data_input_fields
                + util.decoder_data_input_fields[:-1]
                + util.label_data_input_fields
            )
            input_descs = util.get_input_descs(args)
            input_slots = [
                {
                    "name": name,
                    "shape": input_descs[name][0],
                    "dtype": input_descs[name][1],
                }
                for name in input_field_names
            ]
            input_field = util.InputField(input_slots)
            # Define DataLoader
            data_loader = fluid.io.DataLoader.from_generator(
                input_field.feed_list, capacity=60
            )
            data_loader.set_batch_generator(batch_generator, places=place)
            # define model
            transformer = Transformer(
                args.src_vocab_size,
                args.trg_vocab_size,
                args.max_length + 1,
                args.n_layer,
                args.n_head,
                args.d_key,
                args.d_value,
                args.d_model,
                args.d_inner_hid,
                args.prepostprocess_dropout,
                args.attention_dropout,
                args.relu_dropout,
                args.preprocess_cmd,
                args.postprocess_cmd,
                args.weight_sharing,
                args.bos_idx,
                args.eos_idx,
            )
            logits = transformer(*input_field.feed_list[:7])
            # define loss
            criterion = CrossEntropyCriterion(args.label_smooth_eps)
            lbl_word, lbl_weight = input_field.feed_list[7:]
            sum_cost, avg_cost, token_num = criterion(
                logits, lbl_word, lbl_weight
            )
            # define optimizer
            learning_rate = fluid.layers.learning_rate_scheduler.noam_decay(
                args.d_model, args.warmup_steps, args.learning_rate
            )
            optimizer = fluid.optimizer.Adam(
                learning_rate=learning_rate,
                beta1=args.beta1,
                beta2=args.beta2,
                epsilon=float(args.eps),
            )
            optimizer.minimize(avg_cost)
            # the best cross-entropy value with label smoothing
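            # (i.e. the entropy of the smoothed label distribution, the lowest
            # cross-entropy a perfect model could reach); it is subtracted from
            # the raw loss below to report a "normalized loss"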
            loss_normalizer = -(
                (1.0 - args.label_smooth_eps)
                * np.log(1.0 - args.label_smooth_eps)
                + args.label_smooth_eps
                * np.log(
                    args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20
                )
            )
    step_idx = 0
    total_batch_num = 0
    avg_loss = []
    exe = fluid.Executor(place)
    exe.run(startup_prog)
    for pass_id in range(args.epoch):
        batch_id = 0
        for feed_dict in data_loader:
            outs = exe.run(
                program=train_prog,
                feed=feed_dict,
                fetch_list=[sum_cost.name, token_num.name],
            )
            if step_idx % args.print_step == 0:
                sum_cost_val, token_num_val = np.array(outs[0]), np.array(
                    outs[1]
                )
                total_sum_cost = sum_cost_val.sum()
                total_token_num = token_num_val.sum()
                total_avg_cost = total_sum_cost / total_token_num
                avg_loss.append(total_avg_cost)
                if step_idx == 0:
                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f"
                        % (
                            step_idx,
                            pass_id,
                            batch_id,
                            total_avg_cost,
                            total_avg_cost - loss_normalizer,
                            np.exp([min(total_avg_cost, 100)]),
                        )
                    )
                    avg_batch_time = time.time()
                else:
                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f, speed: %.2f steps/s"
                        % (
                            step_idx,
                            pass_id,
                            batch_id,
                            total_avg_cost,
                            total_avg_cost - loss_normalizer,
                            np.exp([min(total_avg_cost, 100)]),
                            args.print_step / (time.time() - avg_batch_time),
                        )
                    )
                    avg_batch_time = time.time()
            batch_id += 1
            step_idx += 1
            total_batch_num = total_batch_num + 1
            if step_idx == STEP_NUM:
                if args.save_static_model_path:
                    model_path = os.path.join(
                        args.save_static_model_path, "transformer"
                    )
                    paddle.static.save(train_prog, model_path)
                break
    return np.array(avg_loss)


def train_dygraph(args, batch_generator):
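    """Train the same Transformer in imperative (dygraph) mode and return the
    average losses collected every `print_step` steps, so they can be checked
    against the static-graph run above."""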
    with fluid.dygraph.guard(place):
        if SEED is not None:
            paddle.seed(SEED)
            paddle.framework.random._manual_program_seed(SEED)
        # define data loader
        train_loader = fluid.io.DataLoader.from_generator(capacity=10)
        train_loader.set_batch_generator(batch_generator, places=place)
        # define model
        transformer = Transformer(
            args.src_vocab_size,
            args.trg_vocab_size,
            args.max_length + 1,
            args.n_layer,
            args.n_head,
            args.d_key,
            args.d_value,
            args.d_model,
            args.d_inner_hid,
            args.prepostprocess_dropout,
            args.attention_dropout,
            args.relu_dropout,
            args.preprocess_cmd,
            args.postprocess_cmd,
            args.weight_sharing,
            args.bos_idx,
            args.eos_idx,
        )
        # define loss
        criterion = CrossEntropyCriterion(args.label_smooth_eps)
        # define learning rate schedule
        learning_rate = fluid.layers.learning_rate_scheduler.noam_decay(
            args.d_model, args.warmup_steps, args.learning_rate
        )
        # define optimizer
        optimizer = fluid.optimizer.Adam(
            learning_rate=learning_rate,
            beta1=args.beta1,
            beta2=args.beta2,
            epsilon=float(args.eps),
            parameter_list=transformer.parameters(),
        )
        # the best cross-entropy value with label smoothing
        loss_normalizer = -(
            (1.0 - args.label_smooth_eps) * np.log(1.0 - args.label_smooth_eps)
            + args.label_smooth_eps
            * np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20)
        )
        ce_time = []
        ce_ppl = []
        avg_loss = []
        step_idx = 0
        for pass_id in range(args.epoch):
            pass_start_time = time.time()
            batch_id = 0
            for input_data in train_loader():
                (
                    src_word,
                    src_pos,
                    src_slf_attn_bias,
                    trg_word,
                    trg_pos,
                    trg_slf_attn_bias,
                    trg_src_attn_bias,
                    lbl_word,
                    lbl_weight,
                ) = input_data
                logits = transformer(
                    src_word,
                    src_pos,
                    src_slf_attn_bias,
                    trg_word,
                    trg_pos,
                    trg_slf_attn_bias,
                    trg_src_attn_bias,
                )
                sum_cost, avg_cost, token_num = criterion(
                    logits, lbl_word, lbl_weight
                )
                avg_cost.backward()
                optimizer.minimize(avg_cost)
                transformer.clear_gradients()
                if step_idx % args.print_step == 0:
                    total_avg_cost = avg_cost.numpy() * trainer_count
                    avg_loss.append(total_avg_cost[0])
                    if step_idx == 0:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f"
                            % (
                                step_idx,
                                pass_id,
                                batch_id,
                                total_avg_cost,
                                total_avg_cost - loss_normalizer,
                                np.exp([min(total_avg_cost, 100)]),
                            )
                        )
                        avg_batch_time = time.time()
                    else:
                        logging.info(
                            "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                            "normalized loss: %f, ppl: %f, speed: %.2f steps/s"
                            % (
                                step_idx,
                                pass_id,
                                batch_id,
                                total_avg_cost,
                                total_avg_cost - loss_normalizer,
                                np.exp([min(total_avg_cost, 100)]),
                                args.print_step
                                / (time.time() - avg_batch_time),
                            )
                        )
                        ce_ppl.append(np.exp([min(total_avg_cost, 100)]))
                        avg_batch_time = time.time()
                batch_id += 1
                step_idx += 1
                if step_idx == STEP_NUM:
                    if args.save_dygraph_model_path:
                        model_dir = args.save_dygraph_model_path
                        if not os.path.exists(model_dir):
                            os.makedirs(model_dir)
                        paddle.save(
                            transformer.state_dict(),
                            os.path.join(model_dir, "transformer")
                            + '.pdparams',
                        )
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(model_dir, "transformer")
                            + '.pdopt',
                        )
                    break
            time_consumed = time.time() - pass_start_time
            ce_time.append(time_consumed)
        return np.array(avg_loss)


def predict_dygraph(args, batch_generator):
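    """Run beam-search inference in dygraph mode with the parameters saved by
    `train_dygraph`; returns the ids and scores of the last evaluated batch."""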
    with fluid.dygraph.guard(place):
        paddle.seed(SEED)
        paddle.framework.random._manual_program_seed(SEED)

        # define data loader
        test_loader = fluid.io.DataLoader.from_generator(capacity=10)
        test_loader.set_batch_generator(batch_generator, places=place)

        # define model
        transformer = Transformer(
            args.src_vocab_size,
            args.trg_vocab_size,
            args.max_length + 1,
            args.n_layer,
            args.n_head,
            args.d_key,
            args.d_value,
            args.d_model,
            args.d_inner_hid,
            args.prepostprocess_dropout,
            args.attention_dropout,
            args.relu_dropout,
            args.preprocess_cmd,
            args.postprocess_cmd,
            args.weight_sharing,
            args.bos_idx,
            args.eos_idx,
        )

        # load the trained model
        model_dict, _ = util.load_dygraph(
            os.path.join(args.save_dygraph_model_path, "transformer")
        )
        # to avoid a longer length than training, reset the size of position
        # encoding to max_length
        model_dict["encoder.pos_encoder.weight"] = position_encoding_init(
            args.max_length + 1, args.d_model
        )
        model_dict["decoder.pos_encoder.weight"] = position_encoding_init(
            args.max_length + 1, args.d_model
        )
        transformer.load_dict(model_dict)

        # set evaluate mode
        transformer.eval()

        step_idx = 0
        speed_list = []
        for input_data in test_loader():
            (
                src_word,
                src_pos,
                src_slf_attn_bias,
                trg_word,
                trg_src_attn_bias,
            ) = input_data
            seq_ids, seq_scores = transformer.beam_search(
                src_word,
                src_pos,
                src_slf_attn_bias,
                trg_word,
                trg_src_attn_bias,
                bos_id=args.bos_idx,
                eos_id=args.eos_idx,
                beam_size=args.beam_size,
                max_len=args.max_out_len,
            )
            seq_ids = seq_ids.numpy()
            seq_scores = seq_scores.numpy()
            if step_idx % args.print_step == 0:
                if step_idx == 0:
                    logging.info(
                        "Dygraph Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f"
                        % (step_idx, seq_ids[0][0][0], seq_scores[0][0])
                    )
                    avg_batch_time = time.time()
                else:
                    speed = args.print_step / (time.time() - avg_batch_time)
                    speed_list.append(speed)
                    logging.info(
                        "Dygraph Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f steps/s"
                        % (step_idx, seq_ids[0][0][0], seq_scores[0][0], speed)
                    )
                    avg_batch_time = time.time()

            step_idx += 1
            if step_idx == STEP_NUM:
                break
        logging.info(
            "Dygraph Predict:  avg_speed: %.4f steps/s" % (np.mean(speed_list))
        )
        return seq_ids, seq_scores


def predict_static(args, batch_generator):
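    """Run beam-search inference with the static program, loading the
    parameters saved by `train_static`; returns the ids and scores of the
    last evaluated batch."""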
    test_prog = fluid.Program()
    with fluid.program_guard(test_prog):
        paddle.seed(SEED)
        paddle.framework.random._manual_program_seed(SEED)

        # define input and reader
        input_field_names = (
            util.encoder_data_input_fields + util.fast_decoder_data_input_fields
        )
        input_descs = util.get_input_descs(args, 'test')
        input_slots = [
            {
                "name": name,
                "shape": input_descs[name][0],
                "dtype": input_descs[name][1],
            }
            for name in input_field_names
        ]

        input_field = util.InputField(input_slots)
        feed_list = input_field.feed_list
        loader = fluid.io.DataLoader.from_generator(
            feed_list=feed_list, capacity=10
        )

        # define model
        transformer = Transformer(
            args.src_vocab_size,
            args.trg_vocab_size,
            args.max_length + 1,
            args.n_layer,
            args.n_head,
            args.d_key,
            args.d_value,
            args.d_model,
            args.d_inner_hid,
            args.prepostprocess_dropout,
            args.attention_dropout,
            args.relu_dropout,
            args.preprocess_cmd,
            args.postprocess_cmd,
            args.weight_sharing,
            args.bos_idx,
            args.eos_idx,
        )

        out_ids, out_scores = transformer.beam_search(
            *feed_list,
            bos_id=args.bos_idx,
            eos_id=args.eos_idx,
            beam_size=args.beam_size,
            max_len=args.max_out_len
        )

    # This is used here to set dropout to the test mode.
    test_prog = test_prog.clone(for_test=True)

    # define the executor for inference
    exe = fluid.Executor(place)

    util.load(
        test_prog, os.path.join(args.save_static_model_path, "transformer"), exe
    )

    loader.set_batch_generator(batch_generator, places=place)

    step_idx = 0
    speed_list = []
    for feed_dict in loader:
        seq_ids, seq_scores = exe.run(
            test_prog,
            feed=feed_dict,
            fetch_list=[out_ids.name, out_scores.name],
            return_numpy=True,
        )
        if step_idx % args.print_step == 0:
            if step_idx == 0:
                logging.info(
                    "Static Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f,"
                    % (step_idx, seq_ids[0][0][0], seq_scores[0][0])
                )
                avg_batch_time = time.time()
            else:
                speed = args.print_step / (time.time() - avg_batch_time)
                speed_list.append(speed)
                logging.info(
                    "Static Predict: step_idx: %d, 1st seq_id: %d, 1st seq_score: %.2f, speed: %.3f steps/s"
                    % (step_idx, seq_ids[0][0][0], seq_scores[0][0], speed)
                )
                avg_batch_time = time.time()

        step_idx += 1
        if step_idx == STEP_NUM:
            break
    logging.info(
        "Static Predict:  avg_speed: %.4f steps/s" % (np.mean(speed_list))
    )

    return seq_ids, seq_scores


class TestTransformer(unittest.TestCase):
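    """Checks that static-graph and dygraph training of the Transformer give
    matching losses; the prediction comparison is currently disabled (see the
    TODO in test_check_result)."""
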
    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def prepare(self, mode='train'):
        args = util.ModelHyperParams()
        args.save_dygraph_model_path = os.path.join(
            self.temp_dir.name, args.save_dygraph_model_path
        )
        args.save_static_model_path = os.path.join(
            self.temp_dir.name, args.save_static_model_path
        )
        args.inference_model_dir = os.path.join(
            self.temp_dir.name, args.inference_model_dir
        )
        args.output_file = os.path.join(self.temp_dir.name, args.output_file)
        batch_generator = util.get_feed_data_reader(args, mode)
        return args, batch_generator

    def _test_train(self):
        args, batch_generator = self.prepare(mode='train')
        static_avg_loss = train_static(args, batch_generator)
        dygraph_avg_loss = train_dygraph(args, batch_generator)
        np.testing.assert_allclose(
            static_avg_loss, dygraph_avg_loss, rtol=1e-05
        )

    def _test_predict(self):
        args, batch_generator = self.prepare(mode='test')
        static_seq_ids, static_scores = predict_static(args, batch_generator)
        dygraph_seq_ids, dygraph_scores = predict_dygraph(
            args, batch_generator
        )

        np.testing.assert_allclose(
            static_seq_ids, dygraph_seq_ids, rtol=1e-05
        )
        np.testing.assert_allclose(static_scores, dygraph_scores, rtol=1e-05)

    def test_check_result(self):
        self._test_train()
        # TODO(zhangliujie) fix predict fail due to precision misalignment
        # self._test_predict()


if __name__ == '__main__':
    unittest.main()