From b88662254b6c0163b6985c2777b2fadc8156b324 Mon Sep 17 00:00:00 2001 From: Zhang Ting <709968123@qq.com> Date: Fri, 17 Apr 2020 13:57:02 +0800 Subject: [PATCH] use 32 bit index to improve expand op (#23899) * use 32 bit index to improve expand op, test=develop * remove redundant code, test=develop --- paddle/fluid/framework/eigen.h | 22 ++++++++++++++++++++++ paddle/fluid/operators/expand_op.h | 9 ++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/eigen.h b/paddle/fluid/framework/eigen.h index 5bafa4345f4..21adcb9948b 100644 --- a/paddle/fluid/framework/eigen.h +++ b/paddle/fluid/framework/eigen.h @@ -115,5 +115,27 @@ struct EigenScalar { } }; +// Define Tensor with 32-bit index. +template +using Tensor32BitIndex = + Eigen::TensorMap, Eigen::Aligned>; + +template +Eigen::DSizes To32BitDims(const DSizes& in) { + Eigen::DSizes out; + for (int i = 0; i < DSizes::count; ++i) { + out[i] = in[i]; + } + return out; +} + +template +Tensor32BitIndex +To32BitIndex(EigenTensor in) { + using RetType = + Tensor32BitIndex; + return RetType(in.data(), To32BitDims(in.dimensions())); +} + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/operators/expand_op.h b/paddle/fluid/operators/expand_op.h index 4f167468ebd..053de589242 100644 --- a/paddle/fluid/operators/expand_op.h +++ b/paddle/fluid/operators/expand_op.h @@ -90,6 +90,7 @@ using EigenVector = framework::EigenVector; template using EigenTensor = framework::EigenTensor; +using framework::To32BitIndex; template class ExpandKernel : public framework::OpKernel { @@ -131,7 +132,13 @@ class ExpandKernel : public framework::OpKernel { auto y = EigenTensor::From(*out0); auto& place = *context.template device_context().eigen_device(); - y.device(place) = x.broadcast(bcast_dims); + // use 32-bit index to speed up + bool use_32bit_index = y.size() < Eigen::NumTraits::highest(); + if (use_32bit_index) { + To32BitIndex(y).device(place) = To32BitIndex(x).broadcast(bcast_dims); + } else { + y.device(place) = x.broadcast(bcast_dims); + } } }; -- GitLab