conv2d_util.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cuda_fp16.h>
#include <vector>
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h"

#include "glog/logging.h"

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

#include "cutlass/conv/device/implicit_gemm_convolution.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/enforce.h"

namespace phi {
namespace fusion {
namespace cutlass_internal {
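// Bails out of the calling kernel launcher when CUTLASS rejects the given
// problem size: logs at VLOG(3) and propagates the non-success status.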
#define CUTLASS_CHECK(status)                                                \
  if (status != cutlass::Status::kSuccess) {                                 \
    VLOG(3)                                                                  \
        << "Cutlass can not deal with this problem size, skip this kernel!"; \
    return status;                                                           \
  }

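// Fused conv2d variants: each enumerator names the epilogue fused onto the
// convolution (bias, activation, and/or residual add).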
typedef enum {
  CONV2D_BIAS,
  CONV2D_BIAS_RELU,
  CONV2D_BIAS_ADD_RELU,
  CONV2D_BIAS_SILU,
  CONV2D_BIAS_LEAKY_RELU,
  CONV2D_BIAS_SILU_ADD
} OpType;

// conv2d_diff_gpu calculates the difference between the cutlass output and the
// baseline output; it is useful for debugging. The return value is the max
// diff between cutlass and the baseline.
float conv2d_diff_gpu(const ConvAllParams& params, OpType op_type);

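// Profiles every candidate kernel in all_func on the given params and returns
// the index of the best-performing candidate, so callers can dispatch to it.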
int ProfileToGetBestConfig(
    const std::vector<std::function<cutlass::Status(ConvAllParams)>>& all_func,
    const ConvAllParams& params,
    OpType op_type);
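
// A minimal usage sketch (illustrative only; the real call sites live in the
// op-specific .cu files and may differ):
//
//   std::vector<std::function<cutlass::Status(ConvAllParams)>> all_func = {
//       /* candidate kernel launchers for one OpType */};
//   int best = ProfileToGetBestConfig(all_func, params, CONV2D_BIAS_RELU);
//   all_func[best](params);  // dispatch to the fastest candidate found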

}  // namespace cutlass_internal
}  // namespace fusion
}  // namespace phi