From 41cee4e102b6e6d9a8e07bc5cd4a63210d01d241 Mon Sep 17 00:00:00 2001
From: guosheng
Date: Wed, 10 Oct 2018 13:11:19 +0800
Subject: [PATCH] Reshape decoder output from 3D to 2D to use GEMM instead of
 BatchedGEMM

---
 .../transformer/model.py   |  6 ++--
 .../transformer/profile.py | 29 ++++++++++++++-----
 2 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/fluid/neural_machine_translation/transformer/model.py b/fluid/neural_machine_translation/transformer/model.py
index 7f537dbc..faec84ae 100644
--- a/fluid/neural_machine_translation/transformer/model.py
+++ b/fluid/neural_machine_translation/transformer/model.py
@@ -523,8 +523,7 @@ def transformer(src_vocab_size,
             epsilon=label_smooth_eps)
 
     cost = layers.softmax_with_cross_entropy(
-        logits=layers.reshape(
-            predict, shape=[-1, trg_vocab_size]),
+        logits=predict,
         label=label,
         soft_label=True if label_smooth_eps else False)
     weighted_cost = cost * weights
@@ -637,6 +636,8 @@ def wrap_decoder(trg_vocab_size,
         preprocess_cmd,
         postprocess_cmd,
         caches=caches)
+    # Reshape to 2D tensor to use GEMM instead of BatchedGEMM
+    dec_output = layers.reshape(dec_output, shape=[-1, dec_output.shape[-1]])
     if weight_sharing:
         predict = layers.matmul(
             x=dec_output,
@@ -751,7 +752,6 @@ def fast_decode(
                 dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
                 enc_output=pre_enc_output,
                 caches=pre_caches)
-            logits = layers.reshape(logits, (-1, trg_vocab_size))
 
             topk_scores, topk_indices = layers.topk(
                 input=layers.softmax(logits), k=beam_size)
diff --git a/fluid/neural_machine_translation/transformer/profile.py b/fluid/neural_machine_translation/transformer/profile.py
index a2ac16df..7cf7d305 100644
--- a/fluid/neural_machine_translation/transformer/profile.py
+++ b/fluid/neural_machine_translation/transformer/profile.py
@@ -1,5 +1,6 @@
 import argparse
 import ast
+import contextlib
 import multiprocessing
 import os
 import six
@@ -79,8 +80,7 @@ def parse_args():
         type=lambda x: str(x.encode().decode("unicode-escape")),
         default=" ",
         help="The delimiter used to split tokens in source or target sentences. "
-        "For EN-DE BPE data we provided, use spaces as token delimiter. "
-        "For EN-FR wordpiece data we provided, use '\x01' as token delimiter.")
+        "For EN-DE BPE data we provided, use spaces as token delimiter.")
     parser.add_argument(
         "--use_mem_opt",
         type=ast.literal_eval,
@@ -98,9 +98,14 @@ def parse_args():
         help="The iteration number to run in profiling.")
     parser.add_argument(
         "--use_parallel_exe",
-        type=bool,
+        type=ast.literal_eval,
         default=False,
         help="The flag indicating whether to use ParallelExecutor.")
+    parser.add_argument(
+        "--profile_ops",
+        type=ast.literal_eval,
+        default=True,
+        help="The flag indicating whether to profile operators.")
     parser.add_argument(
         'opts',
         help='See config.py for all options',
@@ -247,20 +252,30 @@ def main(args):
         return reader_time, run_time
 
+    @contextlib.contextmanager
+    def profile_context(profile=True):
+        if profile:
+            with profiler.profiler('All', 'total', '/tmp/profile_file'):
+                yield
+        else:
+            yield
+
     # start-up
     init_flag = True
-    run(1)
+    run(5)
     init_flag = False
 
     # profiling
     start = time.time()
     # currently only support profiling on one device
-    with profiler.profiler('All', 'total', '/tmp/profile_file'):
+    with profile_context(args.profile_ops):
         reader_time, run_time = run(args.iter_num)
     end = time.time()
 
     total_time = end - start
-    print("Total time: {0}, reader time: {1} s, run time: {2} s".format(
-        total_time, np.sum(reader_time), np.sum(run_time)))
+    print(
+        "Total time: {0}, reader time: {1} s, run time: {2} s, step number: {3}".
+        format(total_time, np.sum(reader_time), np.sum(run_time),
+               args.iter_num))
 
 
 if __name__ == "__main__":
--
GitLab