// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "paddle/phi/backends/gpu/cuda/cuda_graph.h" namespace paddle { namespace platform { using CUDAKernelParams = phi::backends::gpu::CUDAKernelParams; #if CUDA_VERSION < 10010 using cudaStreamCaptureMode = phi::backends::gpu::cudaStreamCaptureMode; #endif using CUDAGraph = phi::backends::gpu::CUDAGraph; using CUDAGraphCaptureModeGuard = phi::backends::gpu::CUDAGraphCaptureModeGuard; template static bool IsBitwiseEqual(const T &x, const T &y) { return std::memcmp(&x, &y, sizeof(T)) == 0; } template struct IsSameKernelHelper; template struct IsSameKernelHelper { private: using FuncArgsTuple = decltype(std::make_tuple(std::declval()...)); template struct Impl { static bool Compare(const CUDAKernelParams ¶ms, const TupleT &args) { using CompareT = typename std::tuple_element::type; if (!IsBitwiseEqual(params.As(IDX), std::get(args))) { return false; } constexpr auto NewIsEnd = (IDX + 1 == std::tuple_size::value); return Impl::Compare(params, args); } }; template struct Impl { static bool Compare(const CUDAKernelParams ¶ms, const TupleT &args) { return true; } }; public: template static bool Compare(const CUDAKernelParams ¶ms, Args... args) { constexpr auto kNumArgs = sizeof...(FuncArgs); static_assert(kNumArgs == sizeof...(Args), "Argument number not match"); auto args_tuple = std::make_tuple(args...); using TupleT = typename std::decay::type; return Impl::Compare(params, args_tuple); } }; } // namespace platform } // namespace paddle