未验证 提交 8650f6ff 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #12898 from luotao1/expand

remove broadcast in sequence_expand
...@@ -53,25 +53,27 @@ struct SequenceExpandFunctor<platform::CPUDeviceContext, T> { ...@@ -53,25 +53,27 @@ struct SequenceExpandFunctor<platform::CPUDeviceContext, T> {
const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/ const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
LoDTensor* out) { LoDTensor* out) {
int out_offset = 0; int out_offset = 0;
auto& eigen_place = *context.eigen_device(); int x_item_length = x.numel() / x.dims()[0];
auto out_data = out->data<T>();
auto x_data = x.data<T>();
for (size_t i = 1; i < ref_lod.size(); ++i) { for (size_t i = 1; i < ref_lod.size(); ++i) {
int repeat_num = ref_lod[i] - ref_lod[i - 1]; int repeat_num = ref_lod[i] - ref_lod[i - 1];
int x_start = x_lod[i - 1]; int x_start = x_lod[i - 1];
int x_end = x_lod[i]; int x_end = x_lod[i];
int x_seq_len = x_end - x_start; int x_seq_len = x_end - x_start;
if (repeat_num > 0) { if (repeat_num > 0) {
auto x_sub_tensor = x.Slice(x_start, x_end);
x_sub_tensor.Resize({1, x_sub_tensor.numel()});
int out_start = out_offset; int out_start = out_offset;
if (out->lod().size() == 1) { if (out->lod().size() == 1) {
out_start = out->lod()[0][out_offset]; out_start = out->lod()[0][out_offset];
} }
auto out_sub_tensor = for (int j = 0; j < repeat_num; j++) {
out->Slice(out_start, out_start + x_seq_len * repeat_num); for (int k = 0; k < x_seq_len; k++) {
out_sub_tensor.Resize({repeat_num, x_sub_tensor.dims()[1]}); for (int l = 0; l < x_item_length; l++) {
EigenMatrix<T>::From(out_sub_tensor).device(eigen_place) = out_data[(out_start + j * x_seq_len + k) * x_item_length + l] =
EigenMatrix<T>::From(x_sub_tensor) x_data[(x_start + k) * x_item_length + l];
.broadcast(Eigen::array<int, 2>({{repeat_num, 1}})); }
}
}
} }
out_offset += repeat_num; out_offset += repeat_num;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册