未验证 提交 f6f8c935 编写于 作者: H HongyuJia 提交者: GitHub

[phi Backend] Change BackendSet from uint64_t to uint32_t (#46532)

* change BackendSet from 64 bits to 32 bits

* fix _MSC_VER error: replace __lzcnt32 with __lzcnt

* fix __GNUC__ error: replace __builtin_clzl with __builtin_clz
上级 7d238139
......@@ -36,7 +36,7 @@ class BackendSet final {
? 0
: 1ULL << (static_cast<uint8_t>(b) - 1)) {}
inline uint64_t bitset() const { return bitset_; }
inline uint32_t bitset() const { return bitset_; }
bool inline Has(Backend b) const {
PD_CHECK(b != Backend::UNDEFINED, "Backend argument can't be UNDEFINED.");
......@@ -62,8 +62,8 @@ class BackendSet final {
}
private:
constexpr BackendSet(uint64_t bitset) : bitset_(bitset) {}
uint64_t bitset_;
constexpr BackendSet(uint32_t bitset) : bitset_(bitset) {}
uint32_t bitset_;
};
} // namespace experimental
......
......@@ -68,18 +68,18 @@ BackendSet GetTensorBackendSet(const phi::TensorBase& t) {
return BackendSet(Backend::UNDEFINED);
}
std::size_t CountLeadingZeros(uint64_t val) {
std::size_t CountLeadingZeros(uint32_t val) {
#if defined(__clang__) || defined(__GNUC__)
return __builtin_clzl(val);
return __builtin_clz(val);
#elif defined(_MSC_VER)
return __lzcnt64(val);
return __lzcnt(val);
#else
if (val == 0) {
return 64;
return 32;
}
std::size_t zero_bits = 0;
for (std::size_t shift = 64 >> 1; shift; shift >>= 1) {
uint64_t tmp = val >> shift;
for (std::size_t shift = 32 >> 1; shift; shift >>= 1) {
uint32_t tmp = val >> shift;
if (tmp) {
val = tmp;
} else {
......
......@@ -36,7 +36,7 @@ namespace experimental {
namespace detail {
BackendSet GetTensorBackendSet(const phi::TensorBase& t);
std::size_t CountLeadingZeros(uint64_t val);
std::size_t CountLeadingZeros(uint32_t val);
} // namespace detail
phi::DeviceContext* GetDeviceContextByBackend(phi::Backend backend);
......@@ -56,7 +56,7 @@ struct KernelKeySet {
// TODO(chenweihang): iterate all kernelkey for kernel selection
phi::KernelKey GetHighestPriorityKernelKey() {
return phi::KernelKey(static_cast<Backend>(64 - detail::CountLeadingZeros(
return phi::KernelKey(static_cast<Backend>(32 - detail::CountLeadingZeros(
backend_set.bitset())),
layout,
dtype);
......@@ -184,7 +184,7 @@ template <typename T, typename... Args>
Backend ParseBackend(T t, Args... args) {
auto backend_set =
BackendSet(ParseBackend(t)) | BackendSet(ParseBackend(args...));
return static_cast<Backend>(64 -
return static_cast<Backend>(32 -
detail::CountLeadingZeros(backend_set.bitset()));
}
Backend ParseBackendWithInputOrder(const Place& place, const Tensor& tensor);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册