Created by: guoshengCS
Refine beam_search_op to output an extra tensor, parent_idx, which is used to gather cell states at the next time step. The lod in selected_ids can also be used for this through sequence_expand, but that is inefficient: sequence_expand has to copy the lod from CPU to GPU, and its computation is more complicated than a gather. Using parent_idx (assigned to GPU) and replacing sequence_expand_op with gather_op speeds up Transformer inference significantly. Partial profiling results of Transformer inference are listed below, first with sequence_expand and then with gather. The total time of the while op drops from 4523.83 to 2837.42 (about 37% lower), and the 1775 sequence_expand calls totaling 2181.03 are replaced by 1775 gather calls totaling 352.45.
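To make the equivalence concrete, here is a minimal, framework-free sketch of why gathering by parent_idx yields the same rows as a LoD-driven expand. It uses plain Python lists in place of tensors, and a simplified single-level offset list in place of the real lod in selected_ids. The helper names (`parent_idx_from_lod`, `expand_by_lod`, `gather`) and the sample values are hypothetical and are not the operators' actual implementation.

```python
# Framework-free sketch: plain lists stand in for tensors.
# Assumption: `lod` is a single level of offsets, one segment per parent beam.


def parent_idx_from_lod(lod):
    """Turn offsets such as [0, 2, 2, 3] into one parent index per selected row."""
    idx = []
    for parent, (start, end) in enumerate(zip(lod, lod[1:])):
        idx.extend([parent] * (end - start))
    return idx


def expand_by_lod(rows, lod):
    """Repeat rows[i] (lod[i+1] - lod[i]) times, as a LoD-driven expand would."""
    out = []
    for row, (start, end) in zip(rows, zip(lod, lod[1:])):
        out.extend([row] * (end - start))
    return out


def gather(rows, index):
    """Pick rows[i] for every i in index."""
    return [rows[i] for i in index]


# Hypothetical step: 3 live beams; beam 0 keeps 2 candidates, beam 1 none, beam 2 one.
cell_states = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
lod = [0, 2, 2, 3]

parent_idx = parent_idx_from_lod(lod)
next_states = gather(cell_states, parent_idx)

# Both routes select the same rows; gather only needs a flat index tensor.
assert next_states == expand_by_lod(cell_states, lod)
```

In this sketch the index is a flat list that can stay on the device with the states, whereas the expand route has to read segment offsets first, which mirrors the CPU-to-GPU lod copy described above.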
With sequence_expand (before this change):

| Event | Calls | Total | Min. | Max. | Ave. | Ratio. |
|---|---|---|---|---|---|---|
| while | 1 | 4523.83 | 4523.83 | 4523.83 | 4523.83 | 0.50234 |
| sequence_expand | 1775 | 2181.03 | 0.030016 | 5.57475 | 1.22875 | 0.242188 |
| matmul | 1787 | 698.191 | 0.032384 | 3.60365 | 0.390706 | 0.0775293 |
| mul | 3456 | 613.506 | 0.034464 | 5.77315 | 0.177519 | 0.0681255 |
| assign | 1787 | 246.091 | 0.018368 | 3.23789 | 0.137712 | 0.0273266 |
| concat | 852 | 117.242 | 0.048992 | 6.85875 | 0.137608 | 0.0130189 |
| elementwise_add | 2729 | 101.151 | 0.011616 | 21.081 | 0.0370652 | 0.0112321 |
| beam_search | 71 | 74.2201 | 0.599296 | 6.65718 | 1.04535 | 0.00824162 |
| transpose2 | 2592 | 70.8863 | 0.012096 | 2.52013 | 0.0273481 | 0.00787142 |
| dropout | 2652 | 70.1645 | 0.011168 | 4.49117 | 0.0264572 | 0.00779127 |
| softmax | 929 | 69.453 | 0.014336 | 3.38106 | 0.074761 | 0.00771226 |
| top_k | 71 | 69.1206 | 0.10224 | 3.11629 | 0.973529 | 0.00767535 |
| layer_norm | 1362 | 50.0504 | 0.017824 | 2.06554 | 0.0367477 | 0.00555774 |
| reshape2 | 2735 | 22.3022 | 0.003328 | 5.13386 | 0.00815438 | 0.00247651 |
| read | 2 | 21.1472 | 10.5665 | 10.5807 | 10.5736 | 0.00234825 |
| relu | 432 | 15.5508 | 0.011328 | 0.48336 | 0.0359973 | 0.00172681 |
| read_from_array | 142 | 11.2653 | 0.051904 | 0.173312 | 0.0793332 | 0.00125093 |
| beam_search_decode | 1 | 11.0426 | 11.0426 | 11.0426 | 11.0426 | 0.0012262 |

With gather (after this change):

| Event | Calls | Total | Min. | Max. | Ave. | Ratio. |
|---|---|---|---|---|---|---|
| while | 1 | 2837.42 | 2837.42 | 2837.42 | 2837.42 | 0.512778 |
| matmul | 1787 | 703.405 | 0.037152 | 3.63952 | 0.393624 | 0.127119 |
| mul | 3456 | 639.879 | 0.032576 | 6.84054 | 0.18515 | 0.115639 |
| gather | 1775 | 352.45 | 0.003392 | 7.06938 | 0.198564 | 0.0636948 |
| assign | 1858 | 251.248 | 0.00352 | 4.82582 | 0.135225 | 0.0454055 |
| concat | 852 | 135.271 | 0.044928 | 19.8283 | 0.158769 | 0.0244462 |
| elementwise_add | 2729 | 83.3219 | 0.011616 | 3.98928 | 0.030532 | 0.0150579 |
| dropout | 2652 | 77.4242 | 0.010848 | 7.27709 | 0.0291946 | 0.0139921 |
| transpose2 | 2592 | 76.0943 | 0.012192 | 2.89446 | 0.0293574 | 0.0137517 |
| beam_search | 71 | 73.5428 | 0.631392 | 10.2284 | 1.03581 | 0.0132906 |
| top_k | 71 | 69 | 0.09632 | 3.12163 | 0.971831 | 0.0124697 |
| softmax | 929 | 68.5287 | 0.014176 | 3.40208 | 0.0737661 | 0.0123845 |
| layer_norm | 1362 | 52.1339 | 0.01824 | 2.51587 | 0.0382774 | 0.00942162 |
| reshape2 | 2735 | 32.8726 | 0.003328 | 4.49635 | 0.0120192 | 0.00594073 |
| relu | 432 | 15.867 | 0.012256 | 0.471616 | 0.0367293 | 0.00286749 |
| read | 2 | 13.7459 | 6.86704 | 6.87885 | 6.87294 | 0.00248415 |
| beam_search_decode | 1 | 12.0194 | 12.0194 | 12.0194 | 12.0194 | 0.00217214 |

Also, make gather_op support gathering from a size-0 tensor.
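The size-0 case can be sketched as follows. This is an illustration in plain Python rather than the operator's code, and `gather_rows` is a hypothetical helper. Rejecting a non-empty index against an empty input is an assumption made here for the sketch, since the description above does not specify that case.

```python
def gather_rows(rows, index):
    """Select rows[i] for each i in index; an empty input yields an empty output."""
    if not rows:
        # Assumption: with nothing to gather from, only an empty index is valid.
        if index:
            raise IndexError("cannot gather from an empty tensor with a non-empty index")
        return []
    return [rows[i] for i in index]


# Size-0 input with an empty index: returns an empty result instead of failing.
empty_result = gather_rows([], [])

# The ordinary case is unchanged.
picked = gather_rows([[1, 2], [3, 4]], [1, 1, 0])
```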