未验证 提交 b8866225 编写于 作者: Z Zhang Ting 提交者: GitHub

use 32 bit index to improve expand op (#23899)

* use 32 bit index to improve expand op, test=develop

* remove redundant code, test=develop
上级 e21b3c27
......@@ -115,5 +115,27 @@ struct EigenScalar {
}
};
// Define Tensor with 32-bit index.
template <typename T, int D, int MajorType = Eigen::RowMajor>
using Tensor32BitIndex =
Eigen::TensorMap<Eigen::Tensor<T, D, MajorType, int>, Eigen::Aligned>;
template <typename DSizes>
Eigen::DSizes<int, DSizes::count> To32BitDims(const DSizes& in) {
Eigen::DSizes<int, DSizes::count> out;
for (int i = 0; i < DSizes::count; ++i) {
out[i] = in[i];
}
return out;
}
template <typename EigenTensor>
Tensor32BitIndex<typename EigenTensor::Scalar, EigenTensor::NumIndices>
To32BitIndex(EigenTensor in) {
using RetType =
Tensor32BitIndex<typename EigenTensor::Scalar, EigenTensor::NumIndices>;
return RetType(in.data(), To32BitDims(in.dimensions()));
}
} // namespace framework
} // namespace paddle
......@@ -90,6 +90,7 @@ using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using framework::To32BitIndex;
template <typename DeviceContext, typename T>
class ExpandKernel : public framework::OpKernel<T> {
......@@ -131,7 +132,13 @@ class ExpandKernel : public framework::OpKernel<T> {
auto y = EigenTensor<T, Rank>::From(*out0);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
y.device(place) = x.broadcast(bcast_dims);
// use 32-bit index to speed up
bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
if (use_32bit_index) {
To32BitIndex(y).device(place) = To32BitIndex(x).broadcast(bcast_dims);
} else {
y.device(place) = x.broadcast(bcast_dims);
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册