# machine_translation.py
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.layers.core import Dense
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops.rnn_cell_impl import RNNCell, BasicLSTMCell
from tensorflow.python.ops.rnn_cell_impl import LSTMStateTuple
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest
import tensorflow.contrib.seq2seq as seq2seq
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
import numpy as np
import os
import argparse
import time

import paddle.v2 as paddle

# Command-line options for the benchmark. Every help string interpolates the
# option's default through argparse's `%(default)...` mechanism, so the format
# specifier must match the type of the default value.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--embedding_dim",
    type=int,
    default=512,
    help="The dimension of embedding table. (default: %(default)d)")
parser.add_argument(
    "--encoder_size",
    type=int,
    default=512,
    help="The size of encoder bi-rnn unit. (default: %(default)d)")
parser.add_argument(
    "--decoder_size",
    type=int,
    default=512,
    help="The size of decoder rnn unit. (default: %(default)d)")
parser.add_argument(
    "--batch_size",
    type=int,
    default=128,
    help="The sequence number of a mini-batch data. (default: %(default)d)")
parser.add_argument(
    "--dict_size",
    type=int,
    default=30000,
    help="The dictionary capacity. Dictionaries of source sequence and "
    "target dictionary have same capacity. (default: %(default)d)")
parser.add_argument(
    "--max_time_steps",
    type=int,
    default=81,
    help="Max number of time steps for sequence. (default: %(default)d)")
parser.add_argument(
    "--pass_num",
    type=int,
    default=10,
    help="The pass number to train. (default: %(default)d)")
parser.add_argument(
    "--learning_rate",
    type=float,
    default=0.0002,
    help="Learning rate used to train the model. (default: %(default)f)")
parser.add_argument(
    "--infer_only", action='store_true', help="If set, run forward only.")
parser.add_argument(
    "--beam_size",
    type=int,
    default=3,
    help="The width for beam searching. (default: %(default)d)")
parser.add_argument(
    "--max_generation_length",
    type=int,
    default=250,
    help="The maximum length of sequence when doing generation. "
    "(default: %(default)d)")
parser.add_argument(
    "--save_freq",
    type=int,
    default=500,
    help="Save model checkpoint every this iteration. (default: %(default)d)")
parser.add_argument(
    "--model_dir",
    type=str,
    default='./checkpoint',
    # The default is a string, so it must be rendered with %s; %d made
    # `--help` fail with a TypeError while formatting this entry.
    help="Path to save model checkpoints. (default: %(default)s)")

# Alias of the private `_Linear` helper from tensorflow's contrib rnn cells;
# LSTMCellWithSimpleAttention.call uses it to compute all four gate
# pre-activations with a single linear map.
_Linear = core_rnn_cell._Linear  # pylint: disable=invalid-name

# Token id fed to the beam-search decoder as the first input of each sequence.
START_TOKEN_IDX = 0
# Token id that ends decoding; also used as the padding value when batching.
END_TOKEN_IDX = 1


class LSTMCellWithSimpleAttention(RNNCell):
    """Add attention mechanism to BasicLSTMCell.

    This class is a wrapper based on tensorflow's `BasicLSTMCell`. The gate
    equations are the same, but the gate pre-activations are computed from
    `[inputs, context, h]`, where `context` is an attention-weighted sum of
    the encoder outputs recomputed at every step from the previous hidden
    state `h`.

    Args:
        num_units: Number of units of the LSTM cell (size of `c` and `h`).
        encoder_vector: Encoder outputs the attention context is read from.
            Entries past `source_sequence_length` are zeroed on construction.
        encoder_proj: Projection of the encoder outputs used to score the
            attention. Its last dimension has to be `num_units`, because it
            is concatenated with a `num_units`-wide projection of the decoder
            state and reshaped to `2 * num_units` columns. Padding is zeroed
            on construction as well.
        source_sequence_length: Length of every source sequence, or None to
            skip the padding reset.
        forget_bias: Bias added to the forget gate before the sigmoid.
        state_is_tuple: If True the state is an `LSTMStateTuple(c, h)`,
            otherwise `c` and `h` are concatenated along axis 1.
        activation: Activation of the inner states; defaults to `tanh`.
        reuse: Forwarded to `RNNCell` as `_reuse`.
    """

    def __init__(self,
                 num_units,
                 encoder_vector,
                 encoder_proj,
                 source_sequence_length,
                 forget_bias=1.0,
                 state_is_tuple=True,
                 activation=None,
                 reuse=None):
        super(LSTMCellWithSimpleAttention, self).__init__(_reuse=reuse)
        if not state_is_tuple:
            # This module never imported a `logging` name, so the warning
            # itself used to raise NameError; go through `tf.logging`.
            tf.logging.warn(
                "%s: Using a concatenated state is slower and will "
                "soon be deprecated. Use state_is_tuple=True.", self)
        self._num_units = num_units
        # set padding part to 0
        self._encoder_vector = self._reset_padding(encoder_vector,
                                                   source_sequence_length)
        self._encoder_proj = self._reset_padding(encoder_proj,
                                                 source_sequence_length)
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation or math_ops.tanh
        # Created lazily on the first `call` because it needs the inputs.
        self._linear = None

    @property
    def state_size(self):
        """Size of the state: an `LSTMStateTuple` or `2 * num_units`."""
        return (LSTMStateTuple(self._num_units, self._num_units) \
                if self._state_is_tuple else 2 * self._num_units)

    @property
    def output_size(self):
        """Size of the cell output, i.e. `num_units`."""
        return self._num_units

    def zero_state(self, batch_size, dtype):
        """Return zero-filled state tensor(s) for the given batch size.

        The last result is cached on the instance and returned again when
        the same `batch_size`, `dtype` and state size are requested.
        """
        state_size = self.state_size
        if hasattr(self, "_last_zero_state"):
            (last_state_size, last_batch_size, last_dtype,
             last_output) = getattr(self, "_last_zero_state")
            if (last_batch_size == batch_size and last_dtype == dtype and
                    last_state_size == state_size):
                return last_output
        with ops.name_scope(
                type(self).__name__ + "ZeroState", values=[batch_size]):
            # `_zero_state_tensors` lives in `rnn_cell_impl`; the bare name
            # was undefined in this module and raised NameError.
            output = rnn_cell_impl._zero_state_tensors(state_size, batch_size,
                                                       dtype)
        self._last_zero_state = (state_size, batch_size, dtype, output)
        return output

    def call(self, inputs, state):
        """Run one LSTM step with an attention context.

        Args:
            inputs: Input tensor of the current step.
            state: Previous state, in the layout selected by
                `state_is_tuple`.

        Returns:
            A pair `(new_h, new_state)`.
        """
        sigmoid = math_ops.sigmoid
        # Parameters of gates are concatenated into one multiply for efficiency.
        if self._state_is_tuple:
            c, h = state
        else:
            c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

        # get context from encoder outputs
        context = self._simple_attention(self._encoder_vector,
                                         self._encoder_proj, h)

        if self._linear is None:
            self._linear = _Linear([inputs, context, h], 4 * self._num_units,
                                   True)
        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = array_ops.split(
            value=self._linear([inputs, context, h]),
            num_or_size_splits=4,
            axis=1)

        new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
                 self._activation(j))
        new_h = self._activation(new_c) * sigmoid(o)

        if self._state_is_tuple:
            new_state = LSTMStateTuple(new_c, new_h)
        else:
            new_state = array_ops.concat([new_c, new_h], 1)
        return new_h, new_state

    def _simple_attention(self, encoder_vec, encoder_proj, decoder_state):
        """Implement the attention function.

        The implementation has the same logic to the fluid decoder: the
        decoder state is projected to `num_units`, repeated along the time
        axis of `encoder_proj`, concatenated with it and mapped to one tanh
        score per position. The scores are normalized with a softmax over
        the time axis and used to sum up `encoder_vec`.

        Returns:
            The context tensor, `encoder_vec` reduced over axis 1.
        """
        decoder_state_proj = tf.contrib.layers.fully_connected(
            inputs=decoder_state,
            num_outputs=self._num_units,
            activation_fn=None,
            biases_initializer=None)
        # Repeat the projected state once per encoder time step.
        decoder_state_expand = tf.tile(
            tf.expand_dims(
                input=decoder_state_proj, axis=1),
            [1, tf.shape(encoder_proj)[1], 1])
        concated = tf.concat([decoder_state_expand, encoder_proj], axis=2)
        # need reduce the first dimension
        attention_weights = tf.contrib.layers.fully_connected(
            inputs=tf.reshape(
                concated, shape=[-1, self._num_units * 2]),
            num_outputs=1,
            activation_fn=tf.nn.tanh,
            biases_initializer=None)
        # Restore the batch dimension that was flattened for the dense layer.
        attention_weights_reshaped = tf.reshape(
            attention_weights, shape=[tf.shape(encoder_vec)[0], -1, 1])
        # normalize the attention weights using softmax
        attention_weights_normed = tf.nn.softmax(
            attention_weights_reshaped, dim=1)
        scaled = tf.multiply(attention_weights_normed, encoder_vec)
        context = tf.reduce_sum(scaled, axis=1)
        return context

    def _reset_padding(self,
                       memory,
                       memory_sequence_length,
                       check_inner_dims_defined=True):
        """Reset the padding part for encoder inputs.

        This function comes from tensorflow's `_prepare_memory` function.

        Args:
            memory: Tensor, or nested structure of tensors, whose axis 1 is
                masked by `memory_sequence_length`.
            memory_sequence_length: Valid length per batch entry, or None to
                return `memory` unmasked.
            check_inner_dims_defined: If True, require every dimension from
                index 2 onward to be statically known.

        Returns:
            `memory` converted to tensor(s), multiplied by the sequence mask
            when lengths are given.

        Raises:
            ValueError: If the inner dimensions are checked and not fully
                defined.
        """
        memory = nest.map_structure(
                lambda m: ops.convert_to_tensor(m, name="memory"), memory)
        if memory_sequence_length is not None:
            memory_sequence_length = ops.convert_to_tensor(
                memory_sequence_length, name="memory_sequence_length")
        if check_inner_dims_defined:

            def _check_dims(m):
                if not m.get_shape()[2:].is_fully_defined():
                    raise ValueError(
                        "Expected memory %s to have fully defined inner dims, "
                        "but saw shape: %s" % (m.name, m.get_shape()))

            nest.map_structure(_check_dims, memory)
        if memory_sequence_length is None:
            seq_len_mask = None
        else:
            seq_len_mask = array_ops.sequence_mask(
                memory_sequence_length,
                maxlen=array_ops.shape(nest.flatten(memory)[0])[1],
                dtype=nest.flatten(memory)[0].dtype)
            # Prefer the static batch size; fall back to the dynamic one.
            seq_len_batch_size = (memory_sequence_length.shape[0].value or
                                  array_ops.shape(memory_sequence_length)[0])

        def _maybe_mask(m, seq_len_mask):
            rank = m.get_shape().ndims
            rank = rank if rank is not None else array_ops.rank(m)
            # Trailing 1-sized dims so the mask broadcasts over inner dims.
            extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32)
            m_batch_size = m.shape[0].value or array_ops.shape(m)[0]
            if memory_sequence_length is not None:
                message = ("memory_sequence_length and memory tensor "
                           "batch sizes do not match.")
                with ops.control_dependencies([
                        check_ops.assert_equal(
                            seq_len_batch_size, m_batch_size, message=message)
                ]):
                    seq_len_mask = array_ops.reshape(
                        seq_len_mask,
                        array_ops.concat(
                            (array_ops.shape(seq_len_mask), extra_ones), 0))
                return m * seq_len_mask
            else:
                return m

        return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask),
                                  memory)


def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
                   target_dict_dim, is_generating, beam_size,
                   max_generation_length):
    """Build the attention-based sequence-to-sequence graph.

    A bidirectional LSTM encodes the source words; a
    `LSTMCellWithSimpleAttention` decoder is then wired either for
    teacher-forced training or for beam-search generation.

    Args:
        embedding_dim: Width of the source and target embedding tables.
        encoder_size: Number of units of each direction of the encoder.
        decoder_size: Number of units of the decoder cell.
        source_dict_dim: Number of rows of the source embedding table.
        target_dict_dim: Number of rows of the target embedding table and
            width of the output projection.
        is_generating: If True build the beam-search graph, otherwise the
            training graph.
        beam_size: Beam width; only used when `is_generating` is True.
        max_generation_length: Maximum number of decoding iterations; only
            used when `is_generating` is True.

    Returns:
        A pair `(feeds, output)`. `feeds` maps feed names to placeholders.
        `output` is the sequence loss when training, and the `predicted_ids`
        of the beam-search decoder when generating.
    """
    src_word_idx = tf.placeholder(tf.int32, shape=[None, None])
    src_sequence_length = tf.placeholder(tf.int32, shape=[None, ])

    src_embedding_weights = tf.get_variable("source_word_embeddings",
                                            [source_dict_dim, embedding_dim])
    src_embedding = tf.nn.embedding_lookup(src_embedding_weights, src_word_idx)

    src_forward_cell = tf.nn.rnn_cell.BasicLSTMCell(encoder_size)
    src_reversed_cell = tf.nn.rnn_cell.BasicLSTMCell(encoder_size)
    # no peephole
    encoder_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=src_forward_cell,
        cell_bw=src_reversed_cell,
        inputs=src_embedding,
        sequence_length=src_sequence_length,
        dtype=tf.float32)

    # concat the forward outputs and backward outputs
    encoded_vec = tf.concat(encoder_outputs, axis=2)

    # project the encoder outputs to size of decoder lstm
    # Each direction emits `encoder_size` features, so the concatenation is
    # `2 * encoder_size` wide. The original reshaped with `embedding_dim`,
    # which is only correct when both options have the same value.
    encoded_proj = tf.contrib.layers.fully_connected(
        inputs=tf.reshape(
            encoded_vec, shape=[-1, encoder_size * 2]),
        num_outputs=decoder_size,
        activation_fn=None,
        biases_initializer=None)
    encoded_proj_reshape = tf.reshape(
        encoded_proj, shape=[-1, tf.shape(encoded_vec)[1], decoder_size])

    # get init state for decoder lstm's H
    # First time step of the backward encoder output, `encoder_size` wide.
    backword_first = tf.slice(encoder_outputs[1], [0, 0, 0], [-1, 1, -1])
    decoder_boot = tf.contrib.layers.fully_connected(
        inputs=tf.reshape(
            backword_first, shape=[-1, encoder_size]),
        num_outputs=decoder_size,
        activation_fn=tf.nn.tanh,
        biases_initializer=None)

    # prepare the initial state for decoder lstm
    cell_init = tf.zeros(tf.shape(decoder_boot), tf.float32)
    initial_state = LSTMStateTuple(cell_init, decoder_boot)

    # create decoder lstm cell
    # For beam search every encoder-side tensor is repeated `beam_size` times
    # so that it lines up with the tiled decoder state.
    decoder_cell = LSTMCellWithSimpleAttention(
        decoder_size,
        encoded_vec
        if not is_generating else seq2seq.tile_batch(encoded_vec, beam_size),
        encoded_proj_reshape if not is_generating else
        seq2seq.tile_batch(encoded_proj_reshape, beam_size),
        src_sequence_length if not is_generating else
        seq2seq.tile_batch(src_sequence_length, beam_size),
        forget_bias=0.0)

    output_layer = Dense(target_dict_dim, name='output_projection')

    if not is_generating:
        trg_word_idx = tf.placeholder(tf.int32, shape=[None, None])
        trg_sequence_length = tf.placeholder(tf.int32, shape=[None, ])
        trg_embedding_weights = tf.get_variable(
            "target_word_embeddings", [target_dict_dim, embedding_dim])
        trg_embedding = tf.nn.embedding_lookup(trg_embedding_weights,
                                               trg_word_idx)

        training_helper = seq2seq.TrainingHelper(
            inputs=trg_embedding,
            sequence_length=trg_sequence_length,
            time_major=False,
            name='training_helper')

        training_decoder = seq2seq.BasicDecoder(
            cell=decoder_cell,
            helper=training_helper,
            initial_state=initial_state,
            output_layer=output_layer)

        # get the max length of target sequence
        max_decoder_length = tf.reduce_max(trg_sequence_length)

        decoder_outputs_train, _, _ = seq2seq.dynamic_decode(
            decoder=training_decoder,
            output_time_major=False,
            impute_finished=True,
            maximum_iterations=max_decoder_length)

        decoder_logits_train = tf.identity(decoder_outputs_train.rnn_output)
        # Not returned; kept so the named op stays available in the graph.
        decoder_pred_train = tf.argmax(
            decoder_logits_train, axis=-1, name='decoder_pred_train')
        masks = tf.sequence_mask(
            lengths=trg_sequence_length,
            maxlen=max_decoder_length,
            dtype=tf.float32,
            name='masks')

        # place holder of label sequence
        lbl_word_idx = tf.placeholder(tf.int32, shape=[None, None])

        # compute the loss
        loss = seq2seq.sequence_loss(
            logits=decoder_logits_train,
            targets=lbl_word_idx,
            weights=masks,
            average_across_timesteps=True,
            average_across_batch=True)

        # return feeding list and loss operator
        return {
            'src_word_idx': src_word_idx,
            'src_sequence_length': src_sequence_length,
            'trg_word_idx': trg_word_idx,
            'trg_sequence_length': trg_sequence_length,
            'lbl_word_idx': lbl_word_idx
        }, loss
    else:
        # One start token per source sequence in the batch.
        start_tokens = tf.ones([tf.shape(src_word_idx)[0], ],
                               tf.int32) * START_TOKEN_IDX
        # share the same embedding weights with target word
        trg_embedding_weights = tf.get_variable(
            "target_word_embeddings", [target_dict_dim, embedding_dim])

        inference_decoder = beam_search_decoder.BeamSearchDecoder(
            cell=decoder_cell,
            embedding=lambda tokens: tf.nn.embedding_lookup(trg_embedding_weights, tokens),
            start_tokens=start_tokens,
            end_token=END_TOKEN_IDX,
            initial_state=tf.nn.rnn_cell.LSTMStateTuple(
                tf.contrib.seq2seq.tile_batch(initial_state[0], beam_size),
                tf.contrib.seq2seq.tile_batch(initial_state[1], beam_size)),
            beam_width=beam_size,
            output_layer=output_layer)

        decoder_outputs_decode, _, _ = seq2seq.dynamic_decode(
            decoder=inference_decoder,
            output_time_major=False,
            #impute_finished=True,# error occurs
            maximum_iterations=max_generation_length)

        predicted_ids = decoder_outputs_decode.predicted_ids

        return {
            'src_word_idx': src_word_idx,
            'src_sequence_length': src_sequence_length
        }, predicted_ids


def print_arguments(args):
    """Print every parsed command-line option as a `name: value` line.

    Args:
        args: Namespace-like object whose attributes are the options, e.g.
            the result of `parser.parse_args()`.
    """
    print('-----------  Configuration Arguments -----------')
    # `dict.iteritems()` only exists on Python 2; `items()` works on both.
    # Sorting by option name keeps the listing stable between runs.
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')


def padding_data(data, padding_size, value):
    """Return a copy of `data` that is exactly `padding_size` items long.

    `padding_size` copies of `value` are appended to the list and the result
    is cut back to its first `padding_size` items, so shorter lists are
    padded and longer lists are truncated. `data` itself is not modified.
    """
    filler = [value] * padding_size
    extended = data + filler
    return extended[:padding_size]


def save(sess, path, var_list=None, global_step=None):
    """Write a checkpoint of `var_list` from `sess` and report its location.

    A fresh `tf.train.Saver` is built for `var_list` on every call; `path`
    and `global_step` are forwarded to `Saver.save`.
    """
    checkpoint_path = tf.train.Saver(var_list).save(
        sess, save_path=path, global_step=global_step)
    print('Model save at %s' % checkpoint_path)


def restore(sess, path, var_list=None):
    """Load the checkpoint at `path` into `sess` and report it.

    A fresh `tf.train.Saver` is built for `var_list`; with `var_list=None`
    the saver covers the list of all saveable variables.
    """
    tf.train.Saver(var_list).restore(sess, save_path=path)
    print('model restored from %s' % path)


def adapt_batch_data(data):
    """Turn a batch of samples into padded arrays keyed by feed name.

    Args:
        data: Sequence of samples; items 0, 1 and 2 of each sample are the
            source, target and label token lists.

    Returns:
        A dict with the int32 arrays `src_word_idx`, `trg_word_idx` and
        `lbl_word_idx`, padded with `END_TOKEN_IDX`, plus the unpadded
        lengths in `src_sequence_length` and `trg_sequence_length`. Labels
        are padded to the longest target length.
    """
    # Materialize the columns as lists: each one is traversed twice below,
    # and `map()` on Python 3 returns an iterator that is empty the second
    # time around.
    src_seq = [sample[0] for sample in data]
    trg_seq = [sample[1] for sample in data]
    lbl_seq = [sample[2] for sample in data]

    src_sequence_length = np.array(
        [len(seq) for seq in src_seq]).astype('int32')
    src_seq_maxlen = np.max(src_sequence_length)

    trg_sequence_length = np.array(
        [len(seq) for seq in trg_seq]).astype('int32')
    trg_seq_maxlen = np.max(trg_sequence_length)

    src_seq = np.array(
        [padding_data(seq, src_seq_maxlen, END_TOKEN_IDX)
         for seq in src_seq]).astype('int32')

    trg_seq = np.array(
        [padding_data(seq, trg_seq_maxlen, END_TOKEN_IDX)
         for seq in trg_seq]).astype('int32')

    lbl_seq = np.array(
        [padding_data(seq, trg_seq_maxlen, END_TOKEN_IDX)
         for seq in lbl_seq]).astype('int32')

    return {
        'src_word_idx': src_seq,
        'src_sequence_length': src_sequence_length,
        'trg_word_idx': trg_seq,
        'trg_sequence_length': trg_sequence_length,
        'lbl_word_idx': lbl_seq
    }


def train():
    """Train the model on WMT14 for `args.pass_num` passes.

    Reads its configuration from the module-level `args`. After every pass
    the mean loss over the test reader and the training throughput are
    printed.
    """
    feeding_dict, loss = seq_to_seq_net(
        embedding_dim=args.embedding_dim,
        encoder_size=args.encoder_size,
        decoder_size=args.decoder_size,
        source_dict_dim=args.dict_size,
        target_dict_dim=args.dict_size,
        is_generating=False,
        beam_size=args.beam_size,
        max_generation_length=args.max_generation_length)

    global_step = tf.Variable(0, trainable=False, name='global_step')
    trainable_params = tf.trainable_variables()
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)

    gradients = tf.gradients(loss, trainable_params)
    # may clip the parameters
    # NOTE(review): the clipped gradients are computed here, but the
    # unclipped `gradients` are what gets applied below. Confirm whether
    # clipping was meant to be active before switching to `clip_gradients`.
    clip_gradients, _ = tf.clip_by_global_norm(gradients, 1.0)

    updates = optimizer.apply_gradients(
        zip(gradients, trainable_params), global_step=global_step)

    src_dict, trg_dict = paddle.dataset.wmt14.get_dict(args.dict_size)

    train_batch_generator = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt14.train(args.dict_size), buf_size=1000),
        batch_size=args.batch_size)

    test_batch_generator = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt14.test(args.dict_size), buf_size=1000),
        batch_size=args.batch_size)

    def make_feed(batch):
        # Pair every placeholder with the array stored under the same name.
        return {
            placeholder: batch[name]
            for name, placeholder in feeding_dict.items()
        }

    def do_validation():
        # Mean of the per-batch losses over the whole test reader. `sess`
        # is bound by the `with` block below before this is called.
        total_loss = 0.0
        count = 0
        for data in test_batch_generator():
            outputs = sess.run([loss],
                               feed_dict=make_feed(adapt_batch_data(data)))
            total_loss += outputs[0]
            count += 1
        # An empty test reader used to raise ZeroDivisionError here.
        return total_loss / count if count else float('nan')

    config = tf.ConfigProto(
        intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        init_g = tf.global_variables_initializer()
        init_l = tf.local_variables_initializer()
        sess.run(init_l)
        sess.run(init_g)
        # `range` instead of the Python-2-only `xrange`.
        for pass_id in range(args.pass_num):
            pass_start_time = time.time()
            words_seen = 0
            for batch_id, data in enumerate(train_batch_generator()):
                adapted_batch_data = adapt_batch_data(data)
                words_seen += np.sum(adapted_batch_data['src_sequence_length'])
                words_seen += np.sum(adapted_batch_data['trg_sequence_length'])
                outputs = sess.run([updates, loss],
                                   feed_dict=make_feed(adapted_batch_data))
                print("pass_id=%d, batch_id=%d, train_loss: %f" %
                      (pass_id, batch_id, outputs[1]))
            # Stop the clock before validation so it only covers training.
            pass_end_time = time.time()
            test_loss = do_validation()
            time_consumed = pass_end_time - pass_start_time
            words_per_sec = words_seen / time_consumed
            print("pass_id=%d, test_loss: %f, words/s: %f, sec/pass: %f" %
                  (pass_id, test_loss, words_per_sec, time_consumed))


def infer():
    """Restore a checkpoint and print beam-search output batch by batch.

    Reads its configuration from the module-level `args`. The checkpoint
    path is hard-coded to './checkpoint/tf_seq2seq-1500'.
    """
    feeding_dict, predicted_ids = seq_to_seq_net(
        embedding_dim=args.embedding_dim,
        encoder_size=args.encoder_size,
        decoder_size=args.decoder_size,
        source_dict_dim=args.dict_size,
        target_dict_dim=args.dict_size,
        is_generating=True,
        beam_size=args.beam_size,
        max_generation_length=args.max_generation_length)

    src_dict, trg_dict = paddle.dataset.wmt14.get_dict(args.dict_size)
    # NOTE(review): despite its name this reader draws from the WMT14
    # training split - confirm that is intended.
    test_batch_generator = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt14.train(args.dict_size), buf_size=1000),
        batch_size=args.batch_size)

    config = tf.ConfigProto(
        intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
    with tf.Session(config=config) as sess:
        restore(sess, './checkpoint/tf_seq2seq-1500')
        for batch_id, data in enumerate(test_batch_generator()):
            # A real list is required: the sequences are traversed three
            # times below, and `map()` on Python 3 yields a one-shot
            # iterator that is empty after the first traversal.
            src_seq = [sample[0] for sample in data]

            # Words of every source sequence in the batch, flattened.
            source_language_seq = [
                src_dict[item] for seq in src_seq for item in seq
            ]

            src_sequence_length = np.array(
                [len(seq) for seq in src_seq]).astype('int32')
            src_seq_maxlen = np.max(src_sequence_length)
            src_seq = np.array([
                padding_data(seq, src_seq_maxlen, END_TOKEN_IDX)
                for seq in src_seq
            ]).astype('int32')

            outputs = sess.run([predicted_ids],
                               feed_dict={
                                   feeding_dict['src_word_idx']: src_seq,
                                   feeding_dict['src_sequence_length']:
                                   src_sequence_length
                               })

            print("\nDecoder result comparison: ")
            # NOTE(review): the source line covers the whole batch, while
            # the inference line below only decodes the first sample.
            source_language_seq = ' '.join(source_language_seq).lstrip(
                '<s>').rstrip('<e>').strip()
            print(" --> source: " + source_language_seq)
            # outputs[0][0] is the first sample; item[0] picks index 0 of
            # the last axis of every step.
            inference_words = []
            for item in outputs[0][0]:
                if item[0] == END_TOKEN_IDX:
                    break
                inference_words.append(trg_dict.get(item[0], '<unk>'))
            inference_seq = ''.join(' ' + word for word in inference_words)
            print(" --> inference: " + inference_seq)


if __name__ == '__main__':
    # `args` is deliberately module-level: train() and infer() read it.
    args = parser.parse_args()
    print_arguments(args)
    # --infer_only selects generation; the default is training.
    entry_point = infer if args.infer_only else train
    entry_point()