Created by: zhoukunsheng
构建网络时生成了序列的表示,其shape为(batch, seq_len, dim),同时生成了一个shape为(batch, seq_len) 的一个binary index,想要依据该binary index 从序列表示中gather 一部分tensor。请问是否可以实现呢?