提交 218d8d8f 编写于 作者: Y Yihua Xu 提交者: Tao Luo

Optimize the computing kernel of sequence_reverse operator (#17349)

* Optimize the computing kernel of sequence_reverse operator.

test=develop

* Clean code

test=develop

* Fix for cpplint syntax checking.

test=develop

* Fix the compile warning issue.

test=develop
上级 dcda2023
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <memory>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h" #include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
...@@ -109,7 +110,6 @@ class SequenceReverseOpKernel : public framework::OpKernel<T> { ...@@ -109,7 +110,6 @@ class SequenceReverseOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(x.lod().size(), 1, PADDLE_ENFORCE_EQ(x.lod().size(), 1,
"SequenceReverse Op only support one level lod."); "SequenceReverse Op only support one level lod.");
auto &dev_ctx = ctx.template device_context<DeviceContext>();
const size_t *lod; const size_t *lod;
size_t lod_count = x.lod()[0].size(); size_t lod_count = x.lod()[0].size();
...@@ -131,10 +131,24 @@ class SequenceReverseOpKernel : public framework::OpKernel<T> { ...@@ -131,10 +131,24 @@ class SequenceReverseOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_NE(x_data, y_data, PADDLE_ENFORCE_NE(x_data, y_data,
"SequenceReverse Op does not support in-place operation"); "SequenceReverse Op does not support in-place operation");
SequenceReverseFunctor<T> functor(x_data, y_data, lod, lod_count, if (platform::is_cpu_place(ctx.GetPlace())) {
row_numel); for (size_t idx = 0; idx < lod_count - 1; idx++) {
platform::ForRange<DeviceContext> for_range(dev_ctx, limit); auto start_pos = lod[idx];
for_range(functor); auto end_pos = lod[idx + 1];
for (auto pos = start_pos; pos < end_pos; pos++) {
auto cur_pos = end_pos - pos - 1 + start_pos;
std::memcpy(y_data + pos * row_numel, x_data + cur_pos * row_numel,
row_numel * sizeof(T));
}
}
} else {
auto &dev_ctx = ctx.template device_context<DeviceContext>();
SequenceReverseFunctor<T> functor(x_data, y_data, lod, lod_count,
row_numel);
platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
for_range(functor);
}
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册