提交 218d8d8f 编写于 作者: Y Yihua Xu 提交者: Tao Luo

Optimize the computing kernel of sequence_reverse operator (#17349)

* Optimize the computing kernel of sequence_reverse operator.

test=develop

* Clean up the code.

test=develop

* Fix cpplint style-check errors.

test=develop

* Fix a compile warning.

test=develop
上级 dcda2023
......@@ -14,6 +14,7 @@
#pragma once
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/platform/for_range.h"
......@@ -109,7 +110,6 @@ class SequenceReverseOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(x.lod().size(), 1,
"SequenceReverse Op only support one level lod.");
auto &dev_ctx = ctx.template device_context<DeviceContext>();
const size_t *lod;
size_t lod_count = x.lod()[0].size();
......@@ -131,10 +131,24 @@ class SequenceReverseOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_NE(x_data, y_data,
"SequenceReverse Op does not support in-place operation");
SequenceReverseFunctor<T> functor(x_data, y_data, lod, lod_count,
row_numel);
platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
for_range(functor);
if (platform::is_cpu_place(ctx.GetPlace())) {
for (size_t idx = 0; idx < lod_count - 1; idx++) {
auto start_pos = lod[idx];
auto end_pos = lod[idx + 1];
for (auto pos = start_pos; pos < end_pos; pos++) {
auto cur_pos = end_pos - pos - 1 + start_pos;
std::memcpy(y_data + pos * row_numel, x_data + cur_pos * row_numel,
row_numel * sizeof(T));
}
}
} else {
auto &dev_ctx = ctx.template device_context<DeviceContext>();
SequenceReverseFunctor<T> functor(x_data, y_data, lod, lod_count,
row_numel);
platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
for_range(functor);
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册