From 0385b0a1ea8628d2a5f4e27d86f5f0c8aed57a56 Mon Sep 17 00:00:00 2001
From: minqiyang
Date: Thu, 11 Oct 2018 19:55:27 +0800
Subject: [PATCH] Accelerate SequencePool Op on SUM mode

test=develop
---
 paddle/fluid/operators/math/CMakeLists.txt    |  4 ++--
 .../fluid/operators/math/sequence_pooling.cc  | 21 ++++++++++++++++---
 2 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index 91101356436..5878c733c47 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -3,8 +3,8 @@ add_subdirectory(detail)
 endif(NOT WIN32)
 
 function(math_library TARGET)
-    # math_library is a function to create math library. 
-    # The interface is the same as cc_library. 
+    # math_library is a function to create math library.
+    # The interface is the same as cc_library.
     # But it handle split GPU/CPU code and link some common library.
     set(cc_srcs)
     set(cu_srcs)
diff --git a/paddle/fluid/operators/math/sequence_pooling.cc b/paddle/fluid/operators/math/sequence_pooling.cc
index 69318a6598c..235b5405fb7 100644
--- a/paddle/fluid/operators/math/sequence_pooling.cc
+++ b/paddle/fluid/operators/math/sequence_pooling.cc
@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/math/sequence_pooling.h"
 #include <string>
+
+#include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/math/sequence_pooling.h"
 
 namespace paddle {
 namespace operators {
@@ -180,6 +182,7 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
     }
     auto lod = input.lod()[0];
     auto& place = *context.eigen_device();
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
     for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
       Tensor in_t = input.Slice(static_cast<int>(lod[i]),
                                 static_cast<int>(lod[i + 1]));
@@ -191,7 +194,14 @@ class SequencePoolFunctor<platform::CPUDeviceContext, T> {
       if (pooltype == "AVERAGE") {
         out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
       } else if (pooltype == "SUM") {
-        out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}}));
+        if (h > 0) {
+          const T* in_data = in_t.data<T>();
+          T* out_data = out_t.mutable_data<T>(context.GetPlace());
+          blas.VCOPY(w, in_data, out_data);
+          for (int64_t r = 1; r != h; ++r) {
+            blas.AXPY(w, 1., in_data + r * w, out_data);
+          }
+        }
       } else if (pooltype == "SQRT") {
         out_e.device(place) = in_e.sum(Eigen::array<int, 1>({{0}})) /
                               std::sqrt(static_cast<T>(h));
@@ -223,6 +233,7 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
     }
     auto lod = in_grad->lod()[0];
     auto& place = *context.eigen_device();
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
     for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
       auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]),
                                    static_cast<int>(lod[i + 1]));
@@ -237,7 +248,11 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
       if (pooltype == "AVERAGE") {
         in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
       } else if (pooltype == "SUM") {
-        in_g_e.device(place) = (out_g_e).broadcast(bcast);
+        const T* out_g_data = out_g_t.data<T>();
+        T* in_g_data = in_g_t.mutable_data<T>(context.GetPlace());
+        for (int r = 0; r != h; ++r) {
+          blas.VCOPY(w, out_g_data, in_g_data + r * w);
+        }
       } else if (pooltype == "SQRT") {
         in_g_e.device(place) =
             (out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
-- 
GitLab
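
The new "SUM" branch above replaces the Eigen column-wise reduction with a VCOPY of the
first row of each sequence followed by AXPY accumulation of the remaining rows. The
standalone sketch below only illustrates that access pattern; SumPoolSequence, vcopy and
axpy are illustrative stand-ins, not PaddlePaddle APIs (the patch itself uses blas.VCOPY
and blas.AXPY from paddle/fluid/operators/math/blas.h).

// Minimal sketch of the SUM-pooling pattern used in the patch, with plain
// loops standing in for the BLAS calls.
#include <cstdint>
#include <iostream>
#include <vector>

template <typename T>
void vcopy(int64_t n, const T* x, T* y) {  // stand-in for blas.VCOPY: y = x
  for (int64_t i = 0; i < n; ++i) y[i] = x[i];
}

template <typename T>
void axpy(int64_t n, T alpha, const T* x, T* y) {  // stand-in for blas.AXPY: y += alpha * x
  for (int64_t i = 0; i < n; ++i) y[i] += alpha * x[i];
}

// Sum-pool one sequence of h rows, each of width w, into a single output row:
// copy row 0, then accumulate rows 1..h-1, mirroring the patched forward branch.
template <typename T>
void SumPoolSequence(const T* in, int64_t h, int64_t w, T* out) {
  if (h <= 0) return;  // empty sequence: nothing to pool
  vcopy(w, in, out);
  for (int64_t r = 1; r != h; ++r) {
    axpy(w, static_cast<T>(1), in + r * w, out);
  }
}

int main() {
  // One sequence with h = 3 rows of width w = 2: (1,2), (3,4), (5,6).
  std::vector<float> in = {1, 2, 3, 4, 5, 6};
  std::vector<float> out(2);
  SumPoolSequence(in.data(), 3, 2, out.data());
  std::cout << out[0] << " " << out[1] << "\n";  // prints: 9 12
  return 0;
}

The backward "SUM" branch in the patch is the mirror image: every one of the h rows of the
input gradient receives a copy of the single output-gradient row, expressed as h VCOPY
calls instead of an Eigen broadcast.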