diff --git a/paddle/fluid/operators/math/math_function_impl.h b/paddle/fluid/operators/math/math_function_impl.h
index d1127ce4a246136cdd1385ef09d905efe63178d8..693d5620460e1fe6f6d82bd0749b0780b64841f5 100644
--- a/paddle/fluid/operators/math/math_function_impl.h
+++ b/paddle/fluid/operators/math/math_function_impl.h
@@ -21,6 +21,8 @@ namespace paddle {
 namespace operators {
 namespace math {
 
+using framework::To32BitIndex;
+
 template <typename DeviceContext, typename T>
 void SetConstant<DeviceContext, T>::operator()(const DeviceContext& context,
                                                framework::Tensor* tensor,
@@ -40,7 +42,15 @@ void Transpose<DeviceContext, T, Rank>::operator()(
   auto eigen_in = framework::EigenTensor<T, Rank>::From(in);
   auto eigen_out = framework::EigenTensor<T, Rank>::From(*out);
   auto* dev = context.eigen_device();
-  eigen_out.device(*dev) = eigen_in.shuffle(permute);
+  // use 32bit index to speed up computation
+  bool use_32bit_index = eigen_out.size() < Eigen::NumTraits<int>::highest();
+  bool is_gpu_place = platform::is_gpu_place(context.GetPlace());
+  if (use_32bit_index && is_gpu_place) {
+    To32BitIndex(eigen_out).device(*dev) =
+        To32BitIndex(eigen_in).shuffle(permute);
+  } else {
+    eigen_out.device(*dev) = eigen_in.shuffle(permute);
+  }
 }
 
 template <typename DeviceContext, typename T>