Commit 41cee4e1 authored by guosheng

Reshape decoder output from 3D to 2D to use GEMM instead of BatchedGEMM

Parent 1080d1b0
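This commit moves the 3D-to-2D reshape of the decoder output into wrap_decoder, so the output projection multiplies a single [batch * seq_len, d_model] matrix (one GEMM) instead of running a batched GEMM over the batch dimension, and the separate reshapes in the loss path and in fast_decode become unnecessary. The NumPy sketch below only illustrates why the two forms are numerically equivalent; the sizes and variable names are invented for the example and are not taken from the patch.

import numpy as np

# Illustrative sizes only (not from the model config).
batch, seq_len, d_model, vocab = 2, 3, 4, 5
dec_output = np.random.rand(batch, seq_len, d_model).astype("float32")
proj_weight = np.random.rand(d_model, vocab).astype("float32")

# Batched GEMM: one [seq_len, d_model] x [d_model, vocab] product per batch item.
batched = np.stack([dec_output[i] @ proj_weight for i in range(batch)])

# Single GEMM: collapse batch and time into one dimension first, as the commit does.
flat = dec_output.reshape(-1, d_model) @ proj_weight   # [batch * seq_len, vocab]
single = flat.reshape(batch, seq_len, vocab)

assert np.allclose(batched, single, atol=1e-5)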
@@ -523,8 +523,7 @@ def transformer(src_vocab_size,
             epsilon=label_smooth_eps)
     cost = layers.softmax_with_cross_entropy(
-        logits=layers.reshape(
-            predict, shape=[-1, trg_vocab_size]),
+        logits=predict,
         label=label,
         soft_label=True if label_smooth_eps else False)
     weighted_cost = cost * weights
@@ -637,6 +636,8 @@ def wrap_decoder(trg_vocab_size,
         preprocess_cmd,
         postprocess_cmd,
         caches=caches)
+    # Reshape to 2D tensor to use GEMM instead of BatchedGEMM
+    dec_output = layers.reshape(dec_output, shape=[-1, dec_output.shape[-1]])
     if weight_sharing:
         predict = layers.matmul(
             x=dec_output,
@@ -751,7 +752,6 @@ def fast_decode(
             dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
             enc_output=pre_enc_output,
             caches=pre_caches)
-        logits = layers.reshape(logits, (-1, trg_vocab_size))
         topk_scores, topk_indices = layers.topk(
             input=layers.softmax(logits), k=beam_size)
......
 import argparse
 import ast
+import contextlib
 import multiprocessing
 import os
 import six
@@ -79,8 +80,7 @@ def parse_args():
         type=lambda x: str(x.encode().decode("unicode-escape")),
         default=" ",
         help="The delimiter used to split tokens in source or target sentences. "
-        "For EN-DE BPE data we provided, use spaces as token delimiter. "
-        "For EN-FR wordpiece data we provided, use '\x01' as token delimiter.")
+        "For EN-DE BPE data we provided, use spaces as token delimiter.")
     parser.add_argument(
         "--use_mem_opt",
         type=ast.literal_eval,
@@ -98,9 +98,14 @@ def parse_args():
         help="The iteration number to run in profiling.")
     parser.add_argument(
         "--use_parallel_exe",
-        type=bool,
+        type=ast.literal_eval,
         default=False,
         help="The flag indicating whether to use ParallelExecutor.")
+    parser.add_argument(
+        "--profile_ops",
+        type=ast.literal_eval,
+        default=True,
+        help="The flag indicating whether to profile operators.")
     parser.add_argument(
         'opts',
         help='See config.py for all options',
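The switch from type=bool to type=ast.literal_eval above also fixes a common argparse pitfall: argparse passes the raw command-line string to the type callable, and bool("False") is True because any non-empty string is truthy, while ast.literal_eval("False") returns the boolean False. A minimal standalone sketch (the flag names here are invented for illustration):

import argparse
import ast

parser = argparse.ArgumentParser()
# type=bool: the string "False" is non-empty, so it converts to True.
parser.add_argument("--broken_flag", type=bool, default=False)
# type=ast.literal_eval: the strings "True"/"False" are parsed as booleans.
parser.add_argument("--fixed_flag", type=ast.literal_eval, default=False)

args = parser.parse_args(["--broken_flag", "False", "--fixed_flag", "False"])
print(args.broken_flag)  # True  -- not what the user asked for
print(args.fixed_flag)   # False -- parsed correctly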
@@ -247,20 +252,30 @@ def main(args):
         return reader_time, run_time
 
+    @contextlib.contextmanager
+    def profile_context(profile=True):
+        if profile:
+            with profiler.profiler('All', 'total', '/tmp/profile_file'):
+                yield
+        else:
+            yield
+
     # start-up
     init_flag = True
     run(1)
     run(5)
     init_flag = False
 
     # profiling
     start = time.time()
-    # currently only support profiling on one device
-    with profiler.profiler('All', 'total', '/tmp/profile_file'):
+    with profile_context(args.profile_ops):
         reader_time, run_time = run(args.iter_num)
     end = time.time()
     total_time = end - start
-    print("Total time: {0}, reader time: {1} s, run time: {2} s".format(
-        total_time, np.sum(reader_time), np.sum(run_time)))
+    print(
+        "Total time: {0}, reader time: {1} s, run time: {2} s, step number: {3}".
+        format(total_time, np.sum(reader_time), np.sum(run_time),
+               args.iter_num))
 
 if __name__ == "__main__":
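The profile_context helper introduced above lets the same run() call either execute under the Paddle profiler or run unwrapped, depending on the new --profile_ops flag. A small self-contained sketch of the same conditional-context-manager pattern, using a timer in place of profiler.profiler so it runs anywhere (the timed() name is invented for this example):

import contextlib
import time

@contextlib.contextmanager
def timed(enabled=True):
    # Only wrap the body in timing when enabled, mirroring profile_context.
    if enabled:
        start = time.time()
        yield
        print("elapsed: {:.3f} s".format(time.time() - start))
    else:
        yield

with timed(enabled=False):
    sum(range(1000))   # runs without any timing overhead
with timed(enabled=True):
    sum(range(1000))   # prints the elapsed time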