From 53d5abe37224658b637c9db5200e7b4e0b7f949b Mon Sep 17 00:00:00 2001
From: limingshu <61349199+JamesLim-sy@users.noreply.github.com>
Date: Fri, 1 Jul 2022 14:23:21 +0800
Subject: [PATCH] Addition of switch_auto_tune option for transpose op
 (#43310)

* 2nd part of transpose update

* add switch_auto_tune option.

* add some changes according to CI

* refine the structure of auto_tune_base.

* merge develop changes

* reset the switch_set_range and change unittest of transpose auto-tune

* change the kernel auto-tune logic
---
 paddle/fluid/operators/fused/fmha_ref.h       | 17 ++---
 .../operators/fused/fused_gate_attention.h    | 28 +++----
 paddle/fluid/operators/transpose_op.cu.h      | 62 +++++----------
 paddle/phi/kernels/autotune/auto_tune_base.h  | 76 ++++++++++++-------
 paddle/phi/kernels/autotune/auto_tune_test.cu | 19 -----
 paddle/phi/kernels/autotune/cache.cc          |  7 ++
 paddle/phi/kernels/autotune/cache.h           |  4 +
 .../phi/kernels/autotune/switch_autotune.cc   |  1 +
 paddle/phi/kernels/gpu/transpose_kernel.cu    |  3 +-
 .../tests/unittests/test_transpose_op.py      | 35 +++++++++
 10 files changed, 131 insertions(+), 121 deletions(-)

diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h
index 3ac57189173..ef1befbb320 100644
--- a/paddle/fluid/operators/fused/fmha_ref.h
+++ b/paddle/fluid/operators/fused/fmha_ref.h
@@ -97,10 +97,9 @@ class FMHARef {
     // input shape: [bs, seq_len, 3, num_head, head_dim]
     // transpose with perm [2, 0, 3, 1, 4],
     // output_shape: [3, bs, num_head, seq_len, head_dim]
-    int ndims = 5;
     std::vector<int> perm_1 = {2, 0, 3, 1, 4};
     TransposeGPUKernelDriver<T>(
-        dev_ctx_, ndims, qkv_input_tensor, perm_1, transpose_2_out_tensor);
+        dev_ctx_, qkv_input_tensor, perm_1, transpose_2_out_tensor);
     T* qkv_data = transpose_2_out_tensor->data<T>();
     T* qk_out_data = qk_out_tensor->data<T>();
     T* qktv_out_data = qktv_out_tensor->data<T>();
@@ -255,9 +254,8 @@ class FMHARef {
     // transpose: [0, 2, 1, 3]
     // output shape: [batch_size, seq_len, num_heads, head_dim]
     std::vector<int> perm_3 = {0, 2, 1, 3};
-    ndims = 4;
     TransposeGPUKernelDriver<T>(
-        dev_ctx_, ndims, *qktv_out_tensor, perm_3, fmha_out_tensor);
+        dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
   }
 
   void ComputeBackward(const Tensor& transpose_2_out_tensor,
@@ -297,10 +295,9 @@ class FMHARef {
     T* qktv_out_grad_data = qktv_out_grad_tensor->data<T>();
 
     // transpose bw
-    int ndims = 4;
     std::vector<int> perm_3 = {0, 2, 1, 3};
     TransposeGPUKernelDriver<T>(
-        dev_ctx_, ndims, fmha_out_grad_tensor, perm_3, qktv_out_grad_tensor);
+        dev_ctx_, fmha_out_grad_tensor, perm_3, qktv_out_grad_tensor);
 
     // recall batchedgemm(nn) fw: softmax_out_data(x) * v_ptr(y) =
     // qktv_out_data(out)
@@ -476,13 +473,9 @@ class FMHARef {
         stride_b);
 
     // transpose bw
-    ndims = 5;
     std::vector<int> perm_1 = {1, 3, 0, 2, 4};
-    TransposeGPUKernelDriver<T>(dev_ctx_,
-                                ndims,
-                                *transpose_2_out_grad_tensor,
-                                perm_1,
-                                qkv_input_grad_tensor);
+    TransposeGPUKernelDriver<T>(
+        dev_ctx_, *transpose_2_out_grad_tensor, perm_1, qkv_input_grad_tensor);
   }
 
  private:
diff --git a/paddle/fluid/operators/fused/fused_gate_attention.h b/paddle/fluid/operators/fused/fused_gate_attention.h
index 2dd923bd64d..45d47908b99 100644
--- a/paddle/fluid/operators/fused/fused_gate_attention.h
+++ b/paddle/fluid/operators/fused/fused_gate_attention.h
@@ -622,11 +622,10 @@ class FMHAGateRef {
                                  Tensor* q_transpose_out,
                                  Tensor* k_transpose_out,
                                  Tensor* v_transpose_out) {
-    int ndims = 5;
     std::vector<int> perm = {0, 1, 3, 2, 4};
-    TransposeGPUKernelDriver<T>(dev_ctx_, ndims, q_out, perm, q_transpose_out);
-    TransposeGPUKernelDriver<T>(dev_ctx_,
-                                ndims, k_out, perm, k_transpose_out);
-    TransposeGPUKernelDriver<T>(dev_ctx_, ndims, v_out, perm, v_transpose_out);
+    TransposeGPUKernelDriver<T>(dev_ctx_, q_out, perm, q_transpose_out);
+    TransposeGPUKernelDriver<T>(dev_ctx_, k_out, perm, k_transpose_out);
+    TransposeGPUKernelDriver<T>(dev_ctx_, v_out, perm, v_transpose_out);
   }
 
   void ComputeQKVTransposeBackward(const Tensor& q_transpose_out_grad,
@@ -635,48 +634,41 @@ class FMHAGateRef {
                                    Tensor* q_out_grad,
                                    Tensor* k_out_grad,
                                    Tensor* v_out_grad) {
-    int ndims = 5;
     std::vector<int> perm = {0, 1, 3, 2, 4};
     TransposeGPUKernelDriver<T>(
-        dev_ctx_, ndims, q_transpose_out_grad, perm, q_out_grad);
+        dev_ctx_, q_transpose_out_grad, perm, q_out_grad);
     TransposeGPUKernelDriver<T>(
-        dev_ctx_, ndims, k_transpose_out_grad, perm, k_out_grad);
+        dev_ctx_, k_transpose_out_grad, perm, k_out_grad);
     TransposeGPUKernelDriver<T>(
-        dev_ctx_, ndims, v_transpose_out_grad, perm, v_out_grad);
+        dev_ctx_, v_transpose_out_grad, perm, v_out_grad);
   }
 
   // [batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim] ->
   // [3, batch_size, seq_len_m, num_heads, seq_len_r, head_dim]
   void ComputeQKVTransposeForward(const Tensor& qkv_out,
                                   Tensor* qkv_transpose_out) {
-    int ndims = 6;
     std::vector<int> perm = {3, 0, 1, 4, 2, 5};
-    TransposeGPUKernelDriver<T>(
-        dev_ctx_, ndims, qkv_out, perm, qkv_transpose_out);
+    TransposeGPUKernelDriver<T>(dev_ctx_, qkv_out, perm, qkv_transpose_out);
   }
 
   void ComputeQKVTransposeBackward(const Tensor& qkv_transpose_out_grad,
                                    Tensor* qkv_out_grad) {
-    int ndims = 6;
     std::vector<int> perm = {1, 2, 4, 0, 3, 5};
     TransposeGPUKernelDriver<T>(
-        dev_ctx_, ndims, qkv_transpose_out_grad, perm, qkv_out_grad);
+        dev_ctx_, qkv_transpose_out_grad, perm, qkv_out_grad);
   }
 
   // [batch_size, seq_len_m, num_head, seq_len_r, c] ->
   // [batch_size, seq_len_m, seq_len_r, num_head, c]
   void ComputeQKTVTransposeForward(const Tensor& qktv_out, Tensor* fmha_out) {
-    int ndims = 5;
     std::vector<int> perm = {0, 1, 3, 2, 4};
-    TransposeGPUKernelDriver<T>(dev_ctx_, ndims, qktv_out, perm, fmha_out);
+    TransposeGPUKernelDriver<T>(dev_ctx_, qktv_out, perm, fmha_out);
   }
 
   void ComputeQKTVTransposeBackward(const Tensor& fmha_out_grad,
                                     Tensor* qktv_out_grad) {
-    int ndims = 5;
     std::vector<int> perm = {0, 1, 3, 2, 4};
-    TransposeGPUKernelDriver<T>(
-        dev_ctx_, ndims, fmha_out_grad, perm, qktv_out_grad);
+    TransposeGPUKernelDriver<T>(dev_ctx_, fmha_out_grad, perm, qktv_out_grad);
   }
 
   // qk_out = qk_out + nonbatched_bias + src_mask
diff --git a/paddle/fluid/operators/transpose_op.cu.h b/paddle/fluid/operators/transpose_op.cu.h
index 1b90ad2c313..0ae020c0dfd 100644
--- a/paddle/fluid/operators/transpose_op.cu.h
+++ b/paddle/fluid/operators/transpose_op.cu.h
@@ -22,7 +22,6 @@ limitations under the License. */
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/autotune/auto_tune_base.h"
-#include "paddle/phi/kernels/autotune/cache.h"
 
 namespace paddle {
 namespace operators {
@@ -1155,50 +1154,31 @@ inline void SimplifyThenLaunch(const int rank,
 }
 
 template <typename T>
-size_t GetTransposeKey(const int rank,
-                       const Tensor& in,
-                       const std::vector<int32_t>& perm) {
-  auto in_shape = phi::vectorize(in.dims());
-  return phi::autotune::GetKey(
-      in_shape, perm, rank, paddle::experimental::CppTypeToDataType<T>::Type());
-}
-
-template <typename T>
-void TransposeGPUKernelDriver(const phi::GPUContext& dev_ctx,
-                              const int rank,
+void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
                               const Tensor& in,
                               const std::vector<int32_t>& perm,
                               Tensor* out) {
-  PADDLE_ENFORCE_LT(
-      rank,
-      phi::DDim::kMaxRank,
-      platform::errors::OutOfRange(
-          "The maximum dimension rank of "
-          "tensor is expected to be less than %d, but here is %d.",
-          phi::DDim::kMaxRank,
-          rank));
-
-  auto ret = TransposeSimple<T>::run(dev_ctx, in, perm, out);
+  const int rank = perm.size();
+  auto ret = TransposeSimple<T>::run(ctx, in, perm, out);
   if (!ret) {
-    auto* tuner = phi::autotune::MakeTransposeTuner<T>(
-        SimplifyThenLaunch<phi::GPUContext, T>);
-    if (!tuner->IsInit()) {
-      tuner->AddCallBack(
-          phi::autotune::MakeCallback<T>(TransCompute<phi::GPUContext, T>));
-      tuner->Finalize();
-    }
-
-    auto key = GetTransposeKey<T>(rank, in, perm);
-    auto& cache = phi::autotune::AutoTuneCache::Instance().GetTranspose();
-    if (cache.Find(key)) {
-      auto index = cache.Get(key);
-      tuner->RunBestKernel(index, rank, dev_ctx, in, out, perm);
-    } else {
-      // All avaliable kernels have ran while picking the best kernel, so
-      // there may be no need for another RunBestKernel.
-      auto index = tuner->PickBestKernel(dev_ctx, rank, dev_ctx, in, out, perm);
-      cache.Set(key, index);
-    }
+    auto* tuner =
+        phi::autotune::MakeTransposeTuner<T>(TransCompute<phi::GPUContext, T>);
+    tuner->AddCallBack(
+        phi::autotune::MakeCallback<T>(SimplifyThenLaunch<phi::GPUContext, T>));
+
+    size_t key = phi::autotune::TransposeKey(
+        phi::vectorize(in.dims()),
+        perm,
+        paddle::experimental::CppTypeToDataType<T>::Type());
+
+    tuner->Run(ctx,
+               phi::autotune::AlgorithmType::kTranspose,
+               key,
+               rank,
+               ctx,
+               in,
+               out,
+               perm);
   }
 }
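The rewritten driver derives rank from perm.size() instead of threading it through every call site, and it keys the autotune cache on the input shape, the permutation, and the element type together, so a tuned kernel choice is reused only when all of them match and any new configuration triggers a fresh timing pass. A rough Python sketch of what TransposeKey computes (the tuple hash is an illustrative stand-in for phi::autotune::GetKey, not its real implementation):

    def transpose_key(x_dims, perm, dtype):
        # Shape, permutation, rank, and dtype all feed the key, so two
        # transpose calls share a cached best kernel only when every one
        # of these fields is identical.
        return hash((tuple(x_dims), tuple(perm), len(perm), str(dtype)))
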
diff --git a/paddle/phi/kernels/autotune/auto_tune_base.h b/paddle/phi/kernels/autotune/auto_tune_base.h
index 95afa7f697b..91685c2ed54 100644
--- a/paddle/phi/kernels/autotune/auto_tune_base.h
+++ b/paddle/phi/kernels/autotune/auto_tune_base.h
@@ -14,12 +14,10 @@
 
 #pragma once
 
-#include <type_traits>
 #include <mutex>
-
 #include "glog/logging.h"
-#include "paddle/phi/core/enforce.h"
 #include "paddle/phi/kernels/autotune/gpu_timer.h"
+#include "paddle/phi/kernels/autotune/switch_autotune.h"
 
 namespace phi {
 namespace autotune {
@@ -51,33 +49,61 @@ class AutoTuneBase {
  public:
   AutoTuneBase() {}
   virtual ~AutoTuneBase() {}
-  explicit AutoTuneBase(KernelType kernel) { kernels_.push_back(kernel); }
 
-  template <typename Type>
-  void AddCallBack(Type kernel) {
-    static_assert(std::is_same<Type, KernelType>::value,
-                  "Type must be the same");
-    kernels_.push_back(kernel);
+  explicit AutoTuneBase(KernelType kernel) {
+    kernels_.push_back(/*default=*/kernel);
   }
 
-  template <typename... Args>
-  void RunBestKernel(const int idx, Args&&... args) {
-    kernels_[idx].Run(args...);
+  void AddCallBack(KernelType kernel) {
+    if (!is_init_) {
+      std::lock_guard<std::mutex> lock(mutex_);
+      kernels_.push_back(kernel);
+    }
   }
 
-  template <typename... Args>
-  void RunDefaultKernel(Args&&... args) {
-    kernels_[0].Run(args...);
+  template <typename Context, typename... Args>
+  void Run(const Context& ctx,
+           const AlgorithmType& algo,
+           const size_t key,
+           Args&&... args) {
+    PADDLE_ENFORCE_GT(
+        kernels_.size(),
+        0,
+        paddle::platform::errors::InvalidArgument(
+            "kernel num must be greater than 0, now is %d", kernels_.size()));
+    is_init_ = true;
+
+    auto& cache = AutoTuneCache::Instance().Get(algo);
+    if (cache.Find(key)) {
+      auto best_idx = cache.Get(key);
+      kernels_[best_idx].Run(args...);
+    } else {
+      bool use_autotune = AutoTuneStatus::Instance().UseAutoTune();
+      if (use_autotune) {
+        // All available kernels have run while picking the best kernel,
+        // so there may be no need for another kernel run.
+        auto best_idx = PickBestKernel(ctx, args...);
+        cache.Set(key, best_idx);
+      } else {
+        kernels_[0].Run(args...);
+      }
+    }
   }
 
+ private:
+  bool is_init_{false};
+  std::vector<KernelType> kernels_;
+  mutable std::mutex mutex_;
+
   template <typename Context, typename... Args>
-  int PickBestKernel(const Context& ctx, Args&&... args) {
+  size_t PickBestKernel(const Context& ctx, Args&&... args) {
+    std::lock_guard<std::mutex> lock(mutex_);
     PADDLE_ENFORCE_GT(
         kernels_.size(),
         0,
         paddle::platform::errors::InvalidArgument(
             "kernel num must be greater than 0, now is %d", kernels_.size()));
-    int best_idx = 0;
+    size_t best_idx = 0;
     float min_time = std::numeric_limits<float>::max();
 
     // Time cost test estabulished in default stream.
@@ -92,23 +118,15 @@ class AutoTuneBase {
     return best_idx;
   }
 
-  bool IsInit() { return is_init_; }
-  void Finalize() { is_init_ = true; }
-
- private:
-  bool is_init_{false};
-  std::vector<KernelType> kernels_;
-
   template <typename Context, typename... Args>
   float RunAndMeasureKernel(const Context& ctx, const int idx, Args&&... args) {
+    // Regard the 1st run as warmup and judge the result by the time cost
+    // of the remaining run cycles.
+    constexpr int repeats = 3;
     phi::GpuTimer timer;
     float time_cost = 0;
     const auto& stream = ctx.stream();
 
-    // Treat 1st run as warm up. Judge the result with
-    // the sum of 2nd and 3rd run.
-    constexpr int repeats = 3;
-
     ctx.Wait();
     for (int i = 0; i < repeats; ++i) {
       timer.Start(stream);
@@ -151,7 +169,7 @@ std::once_flag TransposeAutoTuner<T>::init_flag_;
 
 template <typename T, typename RetureType, typename... Args>
 static AutoTuneBase<T, KernelCallback<T, RetureType, Args...>>*
-    MakeTransposeTuner(RetureType (*func)(Args...)) {
+MakeTransposeTuner(RetureType (*func)(Args...)) {
   auto obj = MakeCallback<T>(func);
   return TransposeAutoTuner<T>::Instance(obj);
 }
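AutoTuneBase::Run() is a measure-and-cache loop: the first call with a given key times every registered kernel (the first of the three runs per kernel is warmup), records the index of the fastest, and later calls with the same key dispatch straight to the winner; when tuning is switched off, the default kernel at index 0 runs. A minimal Python sketch of that control flow, with time.perf_counter standing in for phi::GpuTimer and plain callables for the kernel callbacks (all names here are illustrative, not Paddle API):

    import time

    class AutoTuner:
        def __init__(self, default_kernel):
            self.kernels = [default_kernel]  # index 0 is the default kernel
            self.cache = {}                  # key -> index of fastest kernel

        def add_callback(self, kernel):
            self.kernels.append(kernel)

        def _measure(self, kernel, *args, repeats=3):
            # Regard the 1st run as warmup; judge by the time of the rest.
            cost = 0.0
            for i in range(repeats):
                start = time.perf_counter()
                kernel(*args)
                if i > 0:
                    cost += time.perf_counter() - start
            return cost

        def run(self, key, use_autotune, *args):
            if key in self.cache:
                self.kernels[self.cache[key]](*args)
            elif use_autotune:
                # Timing runs every kernel once anyway, so nothing more
                # needs to be dispatched after the winner is recorded.
                times = [self._measure(k, *args) for k in self.kernels]
                self.cache[key] = times.index(min(times))
            else:
                self.kernels[0](*args)

The C++ version additionally serializes the timing pass with a mutex and flips is_init_ on the first Run(), so AddCallBack() becomes a no-op once tuning has started.
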
diff --git a/paddle/phi/kernels/autotune/auto_tune_test.cu b/paddle/phi/kernels/autotune/auto_tune_test.cu
index d80790dbf2c..2ac7b0b8b75 100644
--- a/paddle/phi/kernels/autotune/auto_tune_test.cu
+++ b/paddle/phi/kernels/autotune/auto_tune_test.cu
@@ -131,24 +131,5 @@ TEST(AutoTune, sum) {
     timer.Stop(0);
     VLOG(3) << "kernel[" << i << "]: time cost is " << timer.ElapsedTime();
   }
-
-  // 2. Test call_back tune.
-  VLOG(3) << ">>> [AutoTune]: Test case.";
-  auto tuner = tune::MakeAutoTuner<float>(Algo<4>);
-  tuner.AddCallBack(tune::MakeCallback<float>(Algo<2>));
-  tuner.AddCallBack(tune::MakeCallback<float>(Algo<1>));
-
-  /* The 1st ctx works for ctx.Wait(),
-     the 2nd is just the param of call_back. */
-  auto best_index = tuner.PickBestKernel(
-      *dev_ctx, *dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks);
-
-  dev_ctx->Wait();
-  phi::GpuTimer timer;
-  timer.Start(0);
-  tuner.RunBestKernel(
-      best_index, *dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks);
-  timer.Stop(0);
-  VLOG(3) << "Best CallBackKernel time cost is " << timer.ElapsedTime();
 #endif
 }
diff --git a/paddle/phi/kernels/autotune/cache.cc b/paddle/phi/kernels/autotune/cache.cc
index 5e2c9e1c742..838f2dd265e 100644
--- a/paddle/phi/kernels/autotune/cache.cc
+++ b/paddle/phi/kernels/autotune/cache.cc
@@ -36,6 +36,13 @@ size_t ConvKey(const std::vector<int64_t>& x_dims,
                static_cast<int64_t>(dtype));
 }
 
+size_t TransposeKey(const std::vector<int64_t>& x_dims,
+                    const std::vector<int32_t>& perm,
+                    phi::DataType dtype) {
+  const auto rank = perm.size();
+  return GetKey(x_dims, perm, rank, static_cast<int64_t>(dtype));
+}
+
 std::string AlgorithmTypeString(int64_t algo_type) {
   if (algo_type == static_cast<int64_t>(AlgorithmType::kConvForward)) {
     return "conv_forward";
diff --git a/paddle/phi/kernels/autotune/cache.h b/paddle/phi/kernels/autotune/cache.h
index 8de0695ede4..1263cf40e56 100644
--- a/paddle/phi/kernels/autotune/cache.h
+++ b/paddle/phi/kernels/autotune/cache.h
@@ -68,6 +68,10 @@ size_t ConvKey(const std::vector<int64_t>& x_dims,
                const std::vector<int>& dilations,
                phi::DataType dtype);
 
+size_t TransposeKey(const std::vector<int64_t>& x_dims,
+                    const std::vector<int32_t>& perm,
+                    phi::DataType dtype);
+
 template <typename AlgorithmT>
 class AlgorithmsCache {
  public:
diff --git a/paddle/phi/kernels/autotune/switch_autotune.cc b/paddle/phi/kernels/autotune/switch_autotune.cc
index 6fda24ef3c8..3742749b3bf 100644
--- a/paddle/phi/kernels/autotune/switch_autotune.cc
+++ b/paddle/phi/kernels/autotune/switch_autotune.cc
@@ -29,6 +29,7 @@ void AutoTuneStatus::EnableAutoTune() {
 
 void AutoTuneStatus::DisableAutoTune() {
   FLAGS_use_autotune = false;
+  use_autotune_ = false;
   Init();
 }
 
diff --git a/paddle/phi/kernels/gpu/transpose_kernel.cu b/paddle/phi/kernels/gpu/transpose_kernel.cu
index 62e29950e2d..3f3760a4890 100644
--- a/paddle/phi/kernels/gpu/transpose_kernel.cu
+++ b/paddle/phi/kernels/gpu/transpose_kernel.cu
@@ -31,12 +31,11 @@ void TransposeKernel(const Context& ctx,
                      const DenseTensor& x,
                      const std::vector<int>& axis,
                      DenseTensor* out) {
-  int rank = axis.size();
   ctx.template Alloc<T>(out);
   if (out->numel() == 0) {
     return;
   }
-  paddle::operators::TransposeGPUKernelDriver<T>(ctx, rank, x, axis, out);
+  paddle::operators::TransposeGPUKernelDriver<T>(ctx, x, axis, out);
 }
 
 }  // namespace phi
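The one-line fix in switch_autotune.cc matters because Run() consults the cached use_autotune_ flag rather than re-reading the GFlag: without clearing the flag, DisableAutoTune() would leave tuning active until the next status update. The switch itself is a per-step window gate; in sketch form (illustrative Python mirroring the observable behavior of AutoTuneStatus, not its actual code):

    class AutoTuneStatus:
        def __init__(self):
            self.enabled = False
            self.use_autotune = False
            self.step = 0
            self.start, self.stop = 1, 10  # tune only inside this step window

        def update(self):
            # Called once per step, e.g. through update_autotune_status().
            self.use_autotune = self.enabled and self.start <= self.step < self.stop
            self.step += 1

        def disable(self):
            self.enabled = False
            self.use_autotune = False  # the fix: stop tuning immediately
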
diff --git a/python/paddle/fluid/tests/unittests/test_transpose_op.py b/python/paddle/fluid/tests/unittests/test_transpose_op.py
index d9e293ba671..fb48f631850 100644
--- a/python/paddle/fluid/tests/unittests/test_transpose_op.py
+++ b/python/paddle/fluid/tests/unittests/test_transpose_op.py
@@ -126,6 +126,41 @@ class TestCase9(TestTransposeOp):
         self.axis = (6, 1, 3, 5, 0, 2, 4, 7)
 
 
+class TestAutoTuneTransposeOp(OpTest):
+
+    def setUp(self):
+        self.init_op_type()
+        self.initTestCase()
+        self.python_api = paddle.transpose
+        self.inputs = {'X': np.random.random(self.shape).astype("float64")}
+        self.attrs = {
+            'axis': list(self.axis),
+            'use_mkldnn': self.use_mkldnn,
+        }
+        self.outputs = {
+            'XShape': np.random.random(self.shape).astype("float64"),
+            'Out': self.inputs['X'].transpose(self.axis)
+        }
+
+    def initTestCase(self):
+        fluid.core.set_autotune_range(0, 3)
+        fluid.core.update_autotune_status()
+        fluid.core.enable_autotune()
+        self.shape = (1, 12, 256, 1)
+        self.axis = (0, 3, 2, 1)
+
+    def init_op_type(self):
+        self.op_type = "transpose2"
+        self.use_mkldnn = False
+
+    def test_check_output(self):
+        self.check_output(no_check_set=['XShape'], check_eager=True)
+        fluid.core.disable_autotune()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', check_eager=True)
+
+
 class TestTransposeBF16Op(OpTest):
     def setUp(self):
-- 
GitLab
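Outside the OpTest harness, the same bindings exercised by the new unittest can confine tuning to the first few iterations of an ordinary program. A usage sketch along those lines (assuming a CUDA build of Paddle, since the tuner only exists in the GPU transpose path):

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    fluid.core.set_autotune_range(0, 3)  # tune only during steps [0, 3)
    fluid.core.enable_autotune()

    x = paddle.to_tensor(np.random.random((1, 12, 256, 1)).astype("float64"))
    for step in range(10):
        fluid.core.update_autotune_status()  # advance the autotune step counter
        # Early iterations time the candidate kernels; later iterations
        # reuse the cached choice for this shape/perm/dtype.
        y = paddle.transpose(x, perm=[0, 3, 2, 1])

    fluid.core.disable_autotune()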