Add a 32-bit-index view over Eigen tensor maps and use it in ExpandKernel.

eigen.h: adds the Tensor32BitIndex alias plus the To32BitDims / To32BitIndex
helpers. To32BitIndex rebuilds a tensor map over the same data pointer, with
the dimensions copied element by element into an int-indexed DSizes.
expand_op.h: ExpandKernel takes the 32-bit-index broadcast path when y.size()
is below Eigen::NumTraits<int>::highest() and keeps the original broadcast
otherwise.

NOTE(review): the template parameter/argument lists in this patch were missing
from the flattened text it was recovered from and are reconstructed here;
confirm them against upstream before applying.

diff --git a/paddle/fluid/framework/eigen.h b/paddle/fluid/framework/eigen.h
index 5bafa4345f42a1f6209b5ee31ae6ba2ded6a899c..21adcb9948b20efe0169a9149b2afce1d485d12d 100644
--- a/paddle/fluid/framework/eigen.h
+++ b/paddle/fluid/framework/eigen.h
@@ -115,5 +115,27 @@ struct EigenScalar {
   }
 };
 
+// Define Tensor with 32-bit index.
+template <typename T, int D, int MajorType = Eigen::RowMajor>
+using Tensor32BitIndex =
+    Eigen::TensorMap<Eigen::Tensor<T, D, MajorType, int>, Eigen::Aligned>;
+
+template <typename DSizes>
+Eigen::DSizes<int, DSizes::count> To32BitDims(const DSizes& in) {
+  Eigen::DSizes<int, DSizes::count> out;
+  for (int i = 0; i < DSizes::count; ++i) {
+    out[i] = in[i];
+  }
+  return out;
+}
+
+template <typename EigenTensor>
+Tensor32BitIndex<typename EigenTensor::Scalar, EigenTensor::NumIndices>
+To32BitIndex(EigenTensor in) {
+  using RetType =
+      Tensor32BitIndex<typename EigenTensor::Scalar, EigenTensor::NumIndices>;
+  return RetType(in.data(), To32BitDims(in.dimensions()));
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/operators/expand_op.h b/paddle/fluid/operators/expand_op.h
index 4f167468ebdf4b5e3f82af535976c7d334ddcafe..053de589242aa9aea58826c089a8eac9d42828e6 100644
--- a/paddle/fluid/operators/expand_op.h
+++ b/paddle/fluid/operators/expand_op.h
@@ -90,6 +90,7 @@ using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
 template <typename T, size_t D, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
+using framework::To32BitIndex;
 
 template <typename DeviceContext, typename T>
 class ExpandKernel : public framework::OpKernel<T> {
@@ -131,7 +132,13 @@ class ExpandKernel : public framework::OpKernel<T> {
     auto y = EigenTensor<T, Rank>::From(*out0);
     auto& place =
         *context.template device_context<DeviceContext>().eigen_device();
-    y.device(place) = x.broadcast(bcast_dims);
+    // use 32-bit index to speed up
+    bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
+    if (use_32bit_index) {
+      To32BitIndex(y).device(place) = To32BitIndex(x).broadcast(bcast_dims);
+    } else {
+      y.device(place) = x.broadcast(bcast_dims);
+    }
   }
 };
 