Unverified commit 278e368a, authored by Guo Sheng, committed by GitHub

Merge pull request #1354 from guoshengCS/opt-transformer-matmul

Optimize Transformer performance
......@@ -80,7 +80,7 @@ def multi_head_attention(queries,
         # The value 0 in shape attr means copying the corresponding dimension
         # size of the input as the output dimension size.
         reshaped = layers.reshape(
-            x=x, shape=[0, 0, n_head, hidden_size // n_head])
+            x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
         # permute the dimensions into:
         # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
......@@ -99,7 +99,9 @@ def multi_head_attention(queries,
         # The value 0 in shape attr means copying the corresponding dimension
         # size of the input as the output dimension size.
         return layers.reshape(
-            x=trans_x, shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]])
+            x=trans_x,
+            shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
+            inplace=True)


 def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
     """
......@@ -523,8 +525,7 @@ def transformer(src_vocab_size,
             epsilon=label_smooth_eps)
     cost = layers.softmax_with_cross_entropy(
-        logits=layers.reshape(
-            predict, shape=[-1, trg_vocab_size]),
+        logits=predict,
         label=label,
         soft_label=True if label_smooth_eps else False)
     weighted_cost = cost * weights
......@@ -637,6 +638,9 @@ def wrap_decoder(trg_vocab_size,
         preprocess_cmd,
         postprocess_cmd,
         caches=caches)
+    # Reshape to 2D tensor to use GEMM instead of BatchedGEMM
+    dec_output = layers.reshape(
+        dec_output, shape=[-1, dec_output.shape[-1]], inplace=True)
     if weight_sharing:
         predict = layers.matmul(
             x=dec_output,
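The new comment states the core of this optimization: a `matmul` whose left operand is 3-D dispatches a batched GEMM over many small matrices, while flattening `[batch, seq_len, d_model]` to `[batch * seq_len, d_model]` runs one large GEMM over the same data, which typically utilizes the hardware better. A numpy sketch of the shape equivalence (all sizes hypothetical):

```python
import numpy as np

batch, seq_len, d_model, vocab = 8, 16, 32, 64  # toy sizes
dec_output = np.random.rand(batch, seq_len, d_model).astype("float32")
weight = np.random.rand(d_model, vocab).astype("float32")

# 3-D input: batch separate [seq_len, d_model] x [d_model, vocab] products.
out_3d = np.matmul(dec_output, weight)
# 2-D input: a single GEMM over the flattened rows -- same numbers.
out_2d = np.matmul(dec_output.reshape(-1, d_model), weight)

assert np.allclose(out_3d.reshape(-1, vocab), out_2d, atol=1e-5)
```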
......@@ -751,7 +755,6 @@ def fast_decode(
                 dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
                 enc_output=pre_enc_output,
                 caches=pre_caches)
-            logits = layers.reshape(logits, (-1, trg_vocab_size))
             topk_scores, topk_indices = layers.topk(
                 input=layers.softmax(logits), k=beam_size)
......
 import argparse
 import ast
+import contextlib
 import multiprocessing
 import os
 import six
......@@ -79,8 +80,7 @@ def parse_args():
         type=lambda x: str(x.encode().decode("unicode-escape")),
         default=" ",
         help="The delimiter used to split tokens in source or target sentences. "
-        "For EN-DE BPE data we provided, use spaces as token delimiter. "
-        "For EN-FR wordpiece data we provided, use '\x01' as token delimiter.")
+        "For EN-DE BPE data we provided, use spaces as token delimiter.")
     parser.add_argument(
         "--use_mem_opt",
         type=ast.literal_eval,
......@@ -98,9 +98,14 @@ def parse_args():
         help="The iteration number to run in profiling.")
     parser.add_argument(
         "--use_parallel_exe",
-        type=bool,
+        type=ast.literal_eval,
         default=False,
         help="The flag indicating whether to use ParallelExecutor.")
+    parser.add_argument(
+        "--profile_ops",
+        type=ast.literal_eval,
+        default=True,
+        help="The flag indicating whether to profile operators.")
     parser.add_argument(
         'opts',
         help='See config.py for all options',
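Switching `--use_parallel_exe` from `type=bool` to `type=ast.literal_eval` also fixes a classic argparse pitfall: `bool("False")` is `True`, since any non-empty string is truthy, so the flag could never be switched off from the command line. A quick demonstration:

```python
import argparse
import ast

parser = argparse.ArgumentParser()
parser.add_argument("--buggy", type=bool, default=False)
parser.add_argument("--fixed", type=ast.literal_eval, default=False)

args = parser.parse_args(["--buggy", "False", "--fixed", "False"])
print(args.buggy)  # True  -- bool() only checks string emptiness
print(args.fixed)  # False -- literal_eval parses the Python literal
```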
......@@ -125,6 +130,8 @@ def parse_args():
 def main(args):
     train_prog = fluid.Program()
     startup_prog = fluid.Program()
+    train_prog.random_seed = 1000
+    startup_prog.random_seed = 1000
     with fluid.program_guard(train_prog, startup_prog):
         with fluid.unique_name.guard():
             sum_cost, avg_cost, predict, token_num, pyreader = transformer(
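Pinning `random_seed` on both programs makes repeated profiling runs execute the same computation (identical parameter initialization and dropout masks), so timings before and after a change are directly comparable. A minimal sketch; seeding numpy for the data side is an extra assumption beyond this diff:

```python
import numpy as np
import paddle.fluid as fluid

SEED = 1000
np.random.seed(SEED)  # assumed: the input pipeline draws from numpy

train_prog = fluid.Program()
startup_prog = fluid.Program()
# Program-level seeds make op-level randomness deterministic across runs.
train_prog.random_seed = SEED
startup_prog.random_seed = SEED
```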
......@@ -243,24 +250,33 @@ def main(args):
             if args.use_py_reader:
                 pyreader.reset()
                 pyreader.start()
+                break
         return reader_time, run_time

+    @contextlib.contextmanager
+    def profile_context(profile=True):
+        if profile:
+            with profiler.profiler('All', 'total', '/tmp/profile_file'):
+                yield
+        else:
+            yield
+
     # start-up
     init_flag = True
-    run(1)
+    run(5)
     init_flag = False

     # profiling
     start = time.time()
-    # currently only support profiling on one device
-    with profiler.profiler('All', 'total', '/tmp/profile_file'):
+    with profile_context(args.profile_ops):
         reader_time, run_time = run(args.iter_num)
     end = time.time()

     total_time = end - start
-    print("Total time: {0}, reader time: {1} s, run time: {2} s".format(
-        total_time, np.sum(reader_time), np.sum(run_time)))
+    print(
+        "Total time: {0}, reader time: {1} s, run time: {2} s, step number: {3}".
+        format(total_time, np.sum(reader_time), np.sum(run_time),
+               args.iter_num))


 if __name__ == "__main__":
......
......@@ -297,9 +297,14 @@ class DataReader(object):
             infos = self._sample_infos

         if self._sort_type == SortType.POOL:
+            reverse = True
             for i in range(0, len(infos), self._pool_size):
+                # to avoid placing short next to long sentences
+                reverse = not reverse
                 infos[i:i + self._pool_size] = sorted(
-                    infos[i:i + self._pool_size], key=lambda x: x.max_len)
+                    infos[i:i + self._pool_size],
+                    key=lambda x: x.max_len,
+                    reverse=reverse)

         # concat batch
         batches = []
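Sorting consecutive pools in alternating directions keeps sequence lengths similar across pool boundaries, so a batch drawn near a boundary never mixes the longest sentences of one pool with the shortest of the next. A toy run of the same loop over bare lengths:

```python
lens = [3, 9, 1, 7, 5, 2, 8, 4, 6, 10]
pool_size, reverse = 5, True

for i in range(0, len(lens), pool_size):
    reverse = not reverse  # flip direction every pool
    lens[i:i + pool_size] = sorted(lens[i:i + pool_size], reverse=reverse)

print(lens)  # [1, 3, 5, 7, 9, 10, 8, 6, 4, 2] -- lengths meet at 9|10
```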
......