未验证 提交 3fcfcd51 编写于 作者: S sneaxiy 提交者: GitHub

Fix bug of CUDAGraph kernel parameter comparison (#43163)

* fix cuda graph sizeof

* fix tuple type
上级 5ccc49e7
......@@ -60,17 +60,18 @@ template <typename Return, typename... FuncArgs,
Return (*kernel_fn)(FuncArgs...)>
struct IsSameKernelHelper<Return (*)(FuncArgs...), kernel_fn> {
private:
using FuncArgsTuple = decltype(std::make_tuple(std::declval<FuncArgs>()...));
// Recursive compile-time helper that checks one kernel argument per
// instantiation: it compares the parameter stored at position IDX of `params`
// with element IDX of `args`, then hands off to the instantiation for IDX + 1.
// The third template argument tells the next step whether the end of the
// tuple has been reached; the IsEnd == true case is not visible in this
// excerpt (presumably a specialization that stops the recursion — confirm).
//
// NOTE(review): this text is a rendered diff whose +/- markers were lost, so
// two declarations below appear twice. Presumably the first of each pair is
// the removed (pre-fix) line and the second is the added (post-fix) line;
// only one of each pair can exist in compilable code.
template <typename TupleT, size_t IDX, bool IsEnd /*=false*/>
struct Impl {
// Returns false as soon as the IDX-th parameter is not bitwise-equal to the
// IDX-th argument; otherwise returns the result of the next recursion step.
static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
// Presumably the removed line: element type taken from the caller-supplied
// tuple type TupleT.
using CompareT = typename std::tuple_element<IDX, TupleT>::type;
// Presumably the added line (commit note "fix tuple type"): element type taken
// from FuncArgsTuple, which is built from the kernel's declared FuncArgs.
using CompareT = typename std::tuple_element<IDX, FuncArgsTuple>::type;
if (!IsBitwiseEqual<CompareT>(params.As<CompareT>(IDX),
std::get<IDX>(args))) {
return false;
}
// Presumably the removed line: sizeof applied to tuple_size<...>::value yields
// the byte size of that constant's type, not the number of tuple elements, so
// the end-of-tuple test does not track the real argument count.
constexpr auto NewIsEnd =
(IDX + 1 == sizeof(std::tuple_size<TupleT>::value));
// Presumably the added line (commit note "fix cuda graph sizeof"): compares
// IDX + 1 against the actual element count of TupleT.
constexpr auto NewIsEnd = (IDX + 1 == std::tuple_size<TupleT>::value);
return Impl<TupleT, IDX + 1, NewIsEnd>::Compare(params, args);
}
};
......@@ -83,8 +84,6 @@ struct IsSameKernelHelper<Return (*)(FuncArgs...), kernel_fn> {
};
public:
using FuncArgsTuple = decltype(std::make_tuple(std::declval<FuncArgs>()...));
template <typename... Args>
static bool Compare(const CUDAKernelParams &params, Args... args) {
constexpr auto kNumArgs = sizeof...(FuncArgs);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册