From 218d8d8f734ad72c9aa1b3b296d55b0eb2ae0b65 Mon Sep 17 00:00:00 2001 From: Yihua Xu Date: Mon, 13 May 2019 18:44:41 +0800 Subject: [PATCH] Optimize the computing kernel of sequence_reverse operator (#17349) * Optimize the computing kernel of sequence_reverse operator. test=develop * Clean code test=develop * Fix for cpplint syntax checking. test=develop * Fix the compile warning issue. test=develop --- .../sequence_ops/sequence_reverse_op.h | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/sequence_ops/sequence_reverse_op.h b/paddle/fluid/operators/sequence_ops/sequence_reverse_op.h index 39dad2311..14e4fc9b0 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_reverse_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_reverse_op.h @@ -14,6 +14,7 @@ #pragma once +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/algorithm.h" #include "paddle/fluid/platform/for_range.h" @@ -109,7 +110,6 @@ class SequenceReverseOpKernel : public framework::OpKernel<T> { PADDLE_ENFORCE_EQ(x.lod().size(), 1, "SequenceReverse Op only support one level lod."); - auto &dev_ctx = ctx.template device_context<DeviceContext>(); const size_t *lod; size_t lod_count = x.lod()[0].size(); @@ -131,10 +131,24 @@ class SequenceReverseOpKernel : public framework::OpKernel<T> { PADDLE_ENFORCE_NE(x_data, y_data, "SequenceReverse Op does not support in-place operation"); - SequenceReverseFunctor<T> functor(x_data, y_data, lod, lod_count, - row_numel); - platform::ForRange<DeviceContext> for_range(dev_ctx, limit); - for_range(functor); + if (platform::is_cpu_place(ctx.GetPlace())) { + for (size_t idx = 0; idx < lod_count - 1; idx++) { + auto start_pos = lod[idx]; + auto end_pos = lod[idx + 1]; + for (auto pos = start_pos; pos < end_pos; pos++) { + auto cur_pos = end_pos - pos - 1 + start_pos; + std::memcpy(y_data + pos * row_numel, x_data + cur_pos * row_numel, + row_numel * sizeof(T)); + } + } + } else { + auto &dev_ctx =
ctx.template device_context<DeviceContext>(); + + SequenceReverseFunctor<T> functor(x_data, y_data, lod, lod_count, + row_numel); + platform::ForRange<DeviceContext> for_range(dev_ctx, limit); + for_range(functor); + } } }; -- GitLab