diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index d140912aa783047ba021be171805adff071bf22b..59540dbaefdd81ace1ca232a1c54ba68fe953562 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -62,3 +62,6 @@ register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS $ add_subdirectory(sparse) copy_if_different(${kernel_declare_file} ${kernel_declare_file_final}) + +# 5. kernel autotune +add_subdirectory(autotune) diff --git a/paddle/phi/kernels/autotune/CMakeLists.txt b/paddle/phi/kernels/autotune/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c7bb30d2d767cfc712fc19152f35bb406a89eac9 --- /dev/null +++ b/paddle/phi/kernels/autotune/CMakeLists.txt @@ -0,0 +1,5 @@ +if (WITH_GPU) + nv_test(gpu_timer_test SRCS gpu_timer_test.cu DEPS gtest) +elseif (WITH_ROCM) + hip_test(gpu_timer_test SRCS gpu_timer_test.cu DEPS gtest) +endif() diff --git a/paddle/phi/kernels/autotune/gpu_timer.h b/paddle/phi/kernels/autotune/gpu_timer.h new file mode 100644 index 0000000000000000000000000000000000000000..87eca2613a7b5290341b448e6910ddbbcc833325 --- /dev/null +++ b/paddle/phi/kernels/autotune/gpu_timer.h @@ -0,0 +1,88 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/backends/gpu/gpu_decls.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" +#ifdef PADDLE_WITH_CUDA +#include +#endif +#ifdef PADDLE_WITH_HIP +#include +#endif + +namespace phi { + +class GpuTimer { + public: + GpuTimer() { +#ifdef PADDLE_WITH_HIP + hipEventCreate(&start_); + hipEventCreate(&stop_); +#else + cudaEventCreate(&start_); + cudaEventCreate(&stop_); +#endif + PADDLE_ENFORCE_NOT_NULL( + start_, phi::errors::PreconditionNotMet("Start Event is not ready.")); + PADDLE_ENFORCE_NOT_NULL( + stop_, phi::errors::PreconditionNotMet("Stop Event is not ready.")); + } + + ~GpuTimer() { +#ifdef PADDLE_WITH_HIP + hipEventDestroy(start_); + hipEventDestroy(stop_); +#else + cudaEventDestroy(start_); + cudaEventDestroy(stop_); +#endif + } + + void Start(gpuStream_t stream) { +#ifdef PADDLE_WITH_HIP + hipEventRecord(start_, stream); +#else + cudaEventRecord(start_, stream); +#endif + } + + void Stop(gpuStream_t stream) { +#ifdef PADDLE_WITH_HIP + hipEventRecord(stop_, stream); +#else + cudaEventRecord(stop_, stream); +#endif + } + + float ElapsedTime() { + float milliseconds = 0; +#ifdef PADDLE_WITH_HIP + hipEventSynchronize(stop_); + hipEventElapsedTime(&milliseconds, start_, stop_); +#else + cudaEventSynchronize(stop_); + cudaEventElapsedTime(&milliseconds, start_, stop_); +#endif + return milliseconds; + } + + private: + gpuEvent_t start_; + gpuEvent_t stop_; +}; + +} // namespace phi diff --git a/paddle/phi/kernels/autotune/gpu_timer_test.cu b/paddle/phi/kernels/autotune/gpu_timer_test.cu new file mode 100644 index 0000000000000000000000000000000000000000..b6eb345885f30e2c0ab2406b65bbe5f2d01f944e --- /dev/null +++ b/paddle/phi/kernels/autotune/gpu_timer_test.cu @@ -0,0 +1,117 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "glog/logging.h" +#include "paddle/phi/kernels/autotune/gpu_timer.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" + +template +__global__ void VecSum(T *x, T *y, int N) { +#ifdef __HIPCC__ + int idx = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x; +#else + int idx = blockDim.x * blockIdx.x + threadIdx.x; +#endif + using LoadT = phi::AlignedVector; + for (int i = idx * VecSize; i < N; i += blockDim.x * gridDim.x * VecSize) { + LoadT x_vec; + LoadT y_vec; + phi::Load(&x[i], &x_vec); + phi::Load(&y[i], &y_vec); +#pragma unroll + for (int j = 0; j < VecSize; j++) { + y_vec[j] = x_vec[j] + y_vec[j]; + } + phi::Store(y_vec, &y[i]); + } +} + +template +void Algo(float *d_in, float *d_out, size_t N) { +#ifdef __HIPCC__ + hipLaunchKernelGGL(HIP_KERNEL_NAME(VecSum), + dim3(Blocks), + dim3(Threads), + 0, + 0, + d_in, + d_out, + N); +#else + VecSum<<>>(d_in, d_out, N); +#endif +} + +TEST(GpuTimer, Sum) { + float *in1, *in2, *out; + float *d_in1, *d_in2; + size_t N = 1 << 20; + size_t size = sizeof(float) * N; +#ifdef __HIPCC__ + hipMalloc(reinterpret_cast(&d_in1), size); + hipMalloc(reinterpret_cast(&d_in2), size); +#else + cudaMalloc(reinterpret_cast(&d_in1), size); + cudaMalloc(reinterpret_cast(&d_in2), size); +#endif + in1 = reinterpret_cast(malloc(size)); + in2 = reinterpret_cast(malloc(size)); + out = reinterpret_cast(malloc(size)); + for (size_t i = 0; i < N; i++) { + in1[i] = 1.0f; + in2[i] = 2.0f; + } + +#ifdef __HIPCC__ + hipMemcpy(d_in1, in1, size, hipMemcpyHostToDevice); + hipMemcpy(d_in2, in2, size, hipMemcpyHostToDevice); +#else + cudaMemcpy(d_in1, in1, size, cudaMemcpyHostToDevice); + cudaMemcpy(d_in2, in2, size, cudaMemcpyHostToDevice); +#endif + + using Functor = std::function; + Functor alog0 = Algo<4, 256, 1024>; + Functor algo1 = Algo<1, 256, 1024>; + Functor alog2 = Algo<1, 256, 8>; + + std::vector algos = {alog0, algo1, alog2}; + + for (int j = 0; j < algos.size(); ++j) { + auto algo = algos[j]; + phi::GpuTimer timer; + timer.Start(0); + algo(d_in1, d_in2, N); + timer.Stop(0); + VLOG(3) << "alog: " << j << " cost: " << timer.ElapsedTime() << "ms"; + } + +#ifdef __HIPCC__ + hipMemcpy(out, d_in2, size, hipMemcpyDeviceToHost); +#else + cudaMemcpy(out, d_in2, size, cudaMemcpyDeviceToHost); +#endif + free(in1); + free(in2); + free(out); +#ifdef __HIPCC__ + hipFree(d_in1); + hipFree(d_in2); +#else + cudaFree(d_in1); + cudaFree(d_in2); +#endif +}