Unverified · Commit 278e368a authored by Guo Sheng, committed by GitHub

Merge pull request #1354 from guoshengCS/opt-transformer-matmul

Optimize Transformer performance
@@ -80,7 +80,7 @@ def multi_head_attention(queries,
         # The value 0 in shape attr means copying the corresponding dimension
         # size of the input as the output dimension size.
         reshaped = layers.reshape(
-            x=x, shape=[0, 0, n_head, hidden_size // n_head])
+            x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)

         # permute the dimensions into:
         # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
@@ -99,7 +99,9 @@ def multi_head_attention(queries,
         # The value 0 in shape attr means copying the corresponding dimension
         # size of the input as the output dimension size.
         return layers.reshape(
-            x=trans_x, shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]])
+            x=trans_x,
+            shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
+            inplace=True)

 def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
     """
@@ -523,8 +525,7 @@ def transformer(src_vocab_size,
             epsilon=label_smooth_eps)

     cost = layers.softmax_with_cross_entropy(
-        logits=layers.reshape(
-            predict, shape=[-1, trg_vocab_size]),
+        logits=predict,
         label=label,
         soft_label=True if label_smooth_eps else False)
     weighted_cost = cost * weights
@@ -637,6 +638,9 @@ def wrap_decoder(trg_vocab_size,
         preprocess_cmd,
         postprocess_cmd,
         caches=caches)
+    # Reshape to 2D tensor to use GEMM instead of BatchedGEMM
+    dec_output = layers.reshape(
+        dec_output, shape=[-1, dec_output.shape[-1]], inplace=True)
     if weight_sharing:
         predict = layers.matmul(
             x=dec_output,
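Flattening `dec_output` from rank 3 to rank 2 is what makes the two reshape deletions elsewhere in this diff possible: `predict` now comes out 2D, so neither the loss hunk above nor `fast_decode` below needs to reshape it. The performance point, per the added comment, is that a rank-3 input to `matmul` dispatches a batched GEMM (one small GEMM per batch element), while the flattened input runs the vocabulary projection as a single large GEMM over all tokens. A toy numpy check (hypothetical shapes) that the two forms compute the same values:

```python
import numpy as np

batch, seq_len, d_model, vocab = 2, 4, 8, 16
dec_output = np.random.rand(batch, seq_len, d_model).astype("float32")
weight = np.random.rand(d_model, vocab).astype("float32")

batched = np.matmul(dec_output, weight)          # BatchedGEMM: one GEMM per batch
flat = dec_output.reshape(-1, d_model) @ weight  # one large GEMM on the flat view
assert np.allclose(batched.reshape(-1, vocab), flat)
```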
@@ -751,7 +755,6 @@ def fast_decode(
                 dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
                 enc_output=pre_enc_output,
                 caches=pre_caches)
-            logits = layers.reshape(logits, (-1, trg_vocab_size))
             topk_scores, topk_indices = layers.topk(
                 input=layers.softmax(logits), k=beam_size)
......
 import argparse
 import ast
+import contextlib
 import multiprocessing
 import os
 import six
@@ -79,8 +80,7 @@ def parse_args():
         type=lambda x: str(x.encode().decode("unicode-escape")),
         default=" ",
         help="The delimiter used to split tokens in source or target sentences. "
-        "For EN-DE BPE data we provided, use spaces as token delimiter. "
-        "For EN-FR wordpiece data we provided, use '\x01' as token delimiter.")
+        "For EN-DE BPE data we provided, use spaces as token delimiter.")
     parser.add_argument(
         "--use_mem_opt",
         type=ast.literal_eval,
@@ -98,9 +98,14 @@ def parse_args():
         help="The iteration number to run in profiling.")
     parser.add_argument(
         "--use_parallel_exe",
-        type=bool,
+        type=ast.literal_eval,
         default=False,
         help="The flag indicating whether to use ParallelExecutor.")
+    parser.add_argument(
+        "--profile_ops",
+        type=ast.literal_eval,
+        default=True,
+        help="The flag indicating whether to profile operators.")
     parser.add_argument(
         'opts',
         help='See config.py for all options',
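The `type=bool` to `type=ast.literal_eval` change fixes a classic argparse pitfall: `bool()` applied to any non-empty string is `True`, so `--use_parallel_exe False` would silently enable the flag. A small self-contained demonstration of the bug and the fix:

```python
import argparse
import ast

parser = argparse.ArgumentParser()
parser.add_argument("--buggy", type=bool, default=False)
parser.add_argument("--fixed", type=ast.literal_eval, default=False)
args = parser.parse_args(["--buggy", "False", "--fixed", "False"])
assert args.buggy is True   # bool("False") == True: the pitfall
assert args.fixed is False  # ast.literal_eval("False") == False
```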
@@ -125,6 +130,8 @@ def parse_args():
 def main(args):
     train_prog = fluid.Program()
     startup_prog = fluid.Program()
+    train_prog.random_seed = 1000
+    startup_prog.random_seed = 1000
     with fluid.program_guard(train_prog, startup_prog):
         with fluid.unique_name.guard():
             sum_cost, avg_cost, predict, token_num, pyreader = transformer(
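Fixing `random_seed` on both programs pins parameter initialization and in-graph randomness such as dropout, so successive profiling runs time the same computation and their numbers are directly comparable. In isolation (a sketch assuming the fluid 1.x `Program.random_seed` attribute):

```python
import paddle.fluid as fluid

train_prog = fluid.Program()
startup_prog = fluid.Program()
startup_prog.random_seed = 1000  # seeds the parameter initializers
train_prog.random_seed = 1000    # seeds random ops (e.g. dropout) in the graph
```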
@@ -243,24 +250,33 @@ def main(args):
             if args.use_py_reader:
                 pyreader.reset()
                 pyreader.start()
+                break
         return reader_time, run_time

+    @contextlib.contextmanager
+    def profile_context(profile=True):
+        if profile:
+            with profiler.profiler('All', 'total', '/tmp/profile_file'):
+                yield
+        else:
+            yield
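`profile_context` is the standard trick for making a `with` block conditional: the same call site either enters the wrapped context or falls straight through to the bare `yield`. A generic, self-contained sketch of the pattern (the `timer` context here is illustrative, not from the diff):

```python
import contextlib
import time

@contextlib.contextmanager
def maybe(ctx_factory, enabled):
    # Enter ctx_factory() only when enabled; otherwise run the body bare.
    if enabled:
        with ctx_factory():
            yield
    else:
        yield

@contextlib.contextmanager
def timer():
    start = time.time()
    yield
    print("elapsed: {:.3f}s".format(time.time() - start))

with maybe(timer, enabled=True):   # body runs inside timer()
    sum(range(1000000))
with maybe(timer, enabled=False):  # body runs unwrapped
    sum(range(1000000))
```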
     # start-up
     init_flag = True
-    run(1)
+    run(5)
     init_flag = False

     # profiling
     start = time.time()
     # currently only support profiling on one device
-    with profiler.profiler('All', 'total', '/tmp/profile_file'):
+    with profile_context(args.profile_ops):
         reader_time, run_time = run(args.iter_num)
     end = time.time()
     total_time = end - start
-    print("Total time: {0}, reader time: {1} s, run time: {2} s".format(
-        total_time, np.sum(reader_time), np.sum(run_time)))
+    print(
+        "Total time: {0}, reader time: {1} s, run time: {2} s, step number: {3}".
+        format(total_time, np.sum(reader_time), np.sum(run_time),
+               args.iter_num))

 if __name__ == "__main__":
......
@@ -297,9 +297,14 @@ class DataReader(object):
             infos = self._sample_infos

         if self._sort_type == SortType.POOL:
+            reverse = True
             for i in range(0, len(infos), self._pool_size):
+                # to avoid placing short next to long sentences
+                reverse = not reverse
                 infos[i:i + self._pool_size] = sorted(
-                    infos[i:i + self._pool_size], key=lambda x: x.max_len)
+                    infos[i:i + self._pool_size],
+                    key=lambda x: x.max_len,
+                    reverse=reverse)

         # concat batch
         batches = []
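A toy illustration of the alternating pool sort above: sorting successive pools in opposite directions keeps sentence lengths continuous across pool boundaries, so a very short sample never lands right after a very long one (the lengths here are made up):

```python
lengths = [5, 1, 9, 3, 7, 2, 8, 4]
pool_size = 4

reverse = True
for i in range(0, len(lengths), pool_size):
    reverse = not reverse  # flip the sort direction per pool, as in the diff
    lengths[i:i + pool_size] = sorted(lengths[i:i + pool_size], reverse=reverse)

print(lengths)  # [1, 3, 5, 9, 8, 7, 4, 2]: ascending pool, then descending
```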
......