Created by: NHZlX
We set the same dims of Multihead matmul op 's input to output original.
eg:
auto dim_input = context->GetInputDim("Input");
context->SetOutputDim("Out", dim_input);
context->ShareLoD("Input", /*->*/ "Out");
The setting above is wrong.
Let's first introduce some noun variables in Multihead matmul op:
1) Batch : the batch size.
2) SeqLen: the sequence length.
3) HiddenSize: the second dims size of the Embedding table.
4) HeadNumber: the head number of the encoder.
5) HeadSize: the size of each Head.
The Multihead matmul op's Input Dims is [Batch * SeqLen * HiddenSize], the Output Dim should be [Batch * SeqLen * (HeadNumber * HeadSize)].