提交 41cee4e1 编写于 作者: G guosheng

Reshape decoder output from 3D to 2D to use GEMM instead of BatchedGEMM

上级 1080d1b0
...@@ -523,8 +523,7 @@ def transformer(src_vocab_size, ...@@ -523,8 +523,7 @@ def transformer(src_vocab_size,
epsilon=label_smooth_eps) epsilon=label_smooth_eps)
cost = layers.softmax_with_cross_entropy( cost = layers.softmax_with_cross_entropy(
logits=layers.reshape( logits=predict,
predict, shape=[-1, trg_vocab_size]),
label=label, label=label,
soft_label=True if label_smooth_eps else False) soft_label=True if label_smooth_eps else False)
weighted_cost = cost * weights weighted_cost = cost * weights
...@@ -637,6 +636,8 @@ def wrap_decoder(trg_vocab_size, ...@@ -637,6 +636,8 @@ def wrap_decoder(trg_vocab_size,
preprocess_cmd, preprocess_cmd,
postprocess_cmd, postprocess_cmd,
caches=caches) caches=caches)
# Reshape to 2D tensor to use GEMM instead of BatchedGEMM
dec_output = layers.reshape(dec_output, shape=[-1, dec_output.shape[-1]])
if weight_sharing: if weight_sharing:
predict = layers.matmul( predict = layers.matmul(
x=dec_output, x=dec_output,
...@@ -751,7 +752,6 @@ def fast_decode( ...@@ -751,7 +752,6 @@ def fast_decode(
dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias), dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
enc_output=pre_enc_output, enc_output=pre_enc_output,
caches=pre_caches) caches=pre_caches)
logits = layers.reshape(logits, (-1, trg_vocab_size))
topk_scores, topk_indices = layers.topk( topk_scores, topk_indices = layers.topk(
input=layers.softmax(logits), k=beam_size) input=layers.softmax(logits), k=beam_size)
......
import argparse import argparse
import ast import ast
import contextlib
import multiprocessing import multiprocessing
import os import os
import six import six
...@@ -79,8 +80,7 @@ def parse_args(): ...@@ -79,8 +80,7 @@ def parse_args():
type=lambda x: str(x.encode().decode("unicode-escape")), type=lambda x: str(x.encode().decode("unicode-escape")),
default=" ", default=" ",
help="The delimiter used to split tokens in source or target sentences. " help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. " "For EN-DE BPE data we provided, use spaces as token delimiter.")
"For EN-FR wordpiece data we provided, use '\x01' as token delimiter.")
parser.add_argument( parser.add_argument(
"--use_mem_opt", "--use_mem_opt",
type=ast.literal_eval, type=ast.literal_eval,
...@@ -98,9 +98,14 @@ def parse_args(): ...@@ -98,9 +98,14 @@ def parse_args():
help="The iteration number to run in profiling.") help="The iteration number to run in profiling.")
parser.add_argument( parser.add_argument(
"--use_parallel_exe", "--use_parallel_exe",
type=bool, type=ast.literal_eval,
default=False, default=False,
help="The flag indicating whether to use ParallelExecutor.") help="The flag indicating whether to use ParallelExecutor.")
parser.add_argument(
"--profile_ops",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to profile operators.")
parser.add_argument( parser.add_argument(
'opts', 'opts',
help='See config.py for all options', help='See config.py for all options',
...@@ -247,20 +252,30 @@ def main(args): ...@@ -247,20 +252,30 @@ def main(args):
return reader_time, run_time return reader_time, run_time
@contextlib.contextmanager
def profile_context(profile=True):
if profile:
with profiler.profiler('All', 'total', '/tmp/profile_file'):
yield
else:
yield
# start-up # start-up
init_flag = True init_flag = True
run(1) run(5)
init_flag = False init_flag = False
# profiling # profiling
start = time.time() start = time.time()
# currently only support profiling on one device # currently only support profiling on one device
with profiler.profiler('All', 'total', '/tmp/profile_file'): with profile_context(args.profile_ops):
reader_time, run_time = run(args.iter_num) reader_time, run_time = run(args.iter_num)
end = time.time() end = time.time()
total_time = end - start total_time = end - start
print("Total time: {0}, reader time: {1} s, run time: {2} s".format( print(
total_time, np.sum(reader_time), np.sum(run_time))) "Total time: {0}, reader time: {1} s, run time: {2} s, step number: {3}".
format(total_time, np.sum(reader_time), np.sum(run_time),
args.iter_num))
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册